mirror of https://github.com/coqui-ai/TTS.git
train wavegrad updates
This commit is contained in:
parent
670f44aa18
commit
c8a4c771a8
|
@ -7,7 +7,9 @@ import traceback
|
|||
|
||||
import torch
|
||||
# DISTRIBUTED
|
||||
try:
|
||||
from apex.parallel import DistributedDataParallel as DDP_apex
|
||||
except:
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP_th
|
||||
from torch.optim import Adam
|
||||
from torch.utils.data import DataLoader
|
||||
|
@ -61,6 +63,7 @@ def setup_loader(ap, is_val=False, verbose=False):
|
|||
def format_data(data):
|
||||
# return a whole audio segment
|
||||
m, x = data
|
||||
x = x.unsqueeze(1)
|
||||
if use_cuda:
|
||||
m = m.cuda(non_blocking=True)
|
||||
x = x.cuda(non_blocking=True)
|
||||
|
@ -70,8 +73,8 @@ def format_data(data):
|
|||
def format_test_data(data):
|
||||
# return a whole audio segment
|
||||
m, x = data
|
||||
m = m.unsqueeze(0)
|
||||
x = x.unsqueeze(0)
|
||||
m = m[None, ...]
|
||||
x = x[None, None, ...]
|
||||
if use_cuda:
|
||||
m = m.cuda(non_blocking=True)
|
||||
x = x.cuda(non_blocking=True)
|
||||
|
@ -94,11 +97,11 @@ def train(model, criterion, optimizer,
|
|||
# setup noise schedule
|
||||
noise_schedule = c['train_noise_schedule']
|
||||
if hasattr(model, 'module'):
|
||||
model.module.init_noise_schedule(noise_schedule['num_steps'],
|
||||
model.module.compute_noise_level(noise_schedule['num_steps'],
|
||||
noise_schedule['min_val'],
|
||||
noise_schedule['max_val'])
|
||||
else:
|
||||
model.init_noise_schedule(noise_schedule['num_steps'],
|
||||
model.compute_noise_level(noise_schedule['num_steps'],
|
||||
noise_schedule['min_val'],
|
||||
noise_schedule['max_val'])
|
||||
for num_iter, data in enumerate(data_loader):
|
||||
|
@ -112,15 +115,17 @@ def train(model, criterion, optimizer,
|
|||
|
||||
# compute noisy input
|
||||
if hasattr(model, 'module'):
|
||||
noise, x_noisy, noise_scale = model.module.compute_noisy_x(x)
|
||||
noise, x_noisy, noise_scale = model.module.compute_y_n(x)
|
||||
else:
|
||||
noise, x_noisy, noise_scale = model.compute_noisy_x(x)
|
||||
noise, x_noisy, noise_scale = model.compute_y_n(x)
|
||||
|
||||
# forward pass
|
||||
noise_hat = model(x_noisy, m, noise_scale)
|
||||
|
||||
# compute losses
|
||||
loss = criterion(noise, noise_hat)
|
||||
# if loss.item() > 100:
|
||||
# breakpoint()
|
||||
loss_wavegrad_dict = {'wavegrad_loss':loss}
|
||||
|
||||
# backward pass with loss scaling
|
||||
|
@ -212,8 +217,8 @@ def train(model, criterion, optimizer,
|
|||
if args.rank == 0:
|
||||
tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
|
||||
# TODO: plot model stats
|
||||
# if c.tb_model_param_stats:
|
||||
# tb_logger.tb_model_weights(model, global_step)
|
||||
if c.tb_model_param_stats:
|
||||
tb_logger.tb_model_weights(model, global_step)
|
||||
return keep_avg.avg_values, global_step
|
||||
|
||||
|
||||
|
@ -236,9 +241,9 @@ def evaluate(model, criterion, ap, global_step, epoch):
|
|||
|
||||
# compute noisy input
|
||||
if hasattr(model, 'module'):
|
||||
noise, x_noisy, noise_scale = model.module.compute_noisy_x(x)
|
||||
noise, x_noisy, noise_scale = model.module.compute_y_n(x)
|
||||
else:
|
||||
noise, x_noisy, noise_scale = model.compute_noisy_x(x)
|
||||
noise, x_noisy, noise_scale = model.compute_y_n(x)
|
||||
|
||||
|
||||
# forward pass
|
||||
|
@ -272,19 +277,20 @@ def evaluate(model, criterion, ap, global_step, epoch):
|
|||
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
|
||||
|
||||
if args.rank == 0:
|
||||
data_loader.dataset.return_segments = False
|
||||
samples = data_loader.dataset.load_test_samples(1)
|
||||
m, x = format_test_data(samples[0])
|
||||
|
||||
# setup noise schedule and inference
|
||||
noise_schedule = c['test_noise_schedule']
|
||||
if hasattr(model, 'module'):
|
||||
model.module.init_noise_schedule(noise_schedule['num_steps'],
|
||||
model.module.compute_noise_level(noise_schedule['num_steps'],
|
||||
noise_schedule['min_val'],
|
||||
noise_schedule['max_val'])
|
||||
# compute voice
|
||||
x_pred = model.module.inference(m)
|
||||
else:
|
||||
model.init_noise_schedule(noise_schedule['num_steps'],
|
||||
model.compute_noise_level(noise_schedule['num_steps'],
|
||||
noise_schedule['min_val'],
|
||||
noise_schedule['max_val'])
|
||||
# compute voice
|
||||
|
@ -300,6 +306,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
|
|||
c.audio["sample_rate"])
|
||||
|
||||
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
|
||||
data_loader.dataset.return_segments = True
|
||||
|
||||
return keep_avg.avg_values
|
||||
|
||||
|
@ -333,6 +340,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
# pylint: disable=import-outside-toplevel
|
||||
from apex import amp
|
||||
model.cuda()
|
||||
# optimizer.cuda()
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
|
||||
else:
|
||||
amp = None
|
||||
|
|
Loading…
Reference in New Issue