mirror of https://github.com/coqui-ai/TTS.git
train wavegrad updates
This commit is contained in:
parent 670f44aa18
commit c8a4c771a8
|
@ -7,8 +7,10 @@ import traceback
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
# DISTRIBUTED
|
# DISTRIBUTED
|
||||||
from apex.parallel import DistributedDataParallel as DDP_apex
|
try:
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP_th
|
from apex.parallel import DistributedDataParallel as DDP_apex
|
||||||
|
except:
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP_th
|
||||||
from torch.optim import Adam
|
from torch.optim import Adam
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
@ -61,6 +63,7 @@ def setup_loader(ap, is_val=False, verbose=False):
|
||||||
def format_data(data):
|
def format_data(data):
|
||||||
# return a whole audio segment
|
# return a whole audio segment
|
||||||
m, x = data
|
m, x = data
|
||||||
|
x = x.unsqueeze(1)
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
m = m.cuda(non_blocking=True)
|
m = m.cuda(non_blocking=True)
|
||||||
x = x.cuda(non_blocking=True)
|
x = x.cuda(non_blocking=True)
|
||||||
|
@ -70,8 +73,8 @@ def format_data(data):
|
||||||
def format_test_data(data):
|
def format_test_data(data):
|
||||||
# return a whole audio segment
|
# return a whole audio segment
|
||||||
m, x = data
|
m, x = data
|
||||||
m = m.unsqueeze(0)
|
m = m[None, ...]
|
||||||
x = x.unsqueeze(0)
|
x = x[None, None, ...]
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
m = m.cuda(non_blocking=True)
|
m = m.cuda(non_blocking=True)
|
||||||
x = x.cuda(non_blocking=True)
|
x = x.cuda(non_blocking=True)
|
||||||
|
@ -94,11 +97,11 @@ def train(model, criterion, optimizer,
|
||||||
# setup noise schedule
|
# setup noise schedule
|
||||||
noise_schedule = c['train_noise_schedule']
|
noise_schedule = c['train_noise_schedule']
|
||||||
if hasattr(model, 'module'):
|
if hasattr(model, 'module'):
|
||||||
model.module.init_noise_schedule(noise_schedule['num_steps'],
|
model.module.compute_noise_level(noise_schedule['num_steps'],
|
||||||
noise_schedule['min_val'],
|
noise_schedule['min_val'],
|
||||||
noise_schedule['max_val'])
|
noise_schedule['max_val'])
|
||||||
else:
|
else:
|
||||||
model.init_noise_schedule(noise_schedule['num_steps'],
|
model.compute_noise_level(noise_schedule['num_steps'],
|
||||||
noise_schedule['min_val'],
|
noise_schedule['min_val'],
|
||||||
noise_schedule['max_val'])
|
noise_schedule['max_val'])
|
||||||
for num_iter, data in enumerate(data_loader):
|
for num_iter, data in enumerate(data_loader):
|
||||||
|
@ -112,15 +115,17 @@ def train(model, criterion, optimizer,
|
||||||
|
|
||||||
# compute noisy input
|
# compute noisy input
|
||||||
if hasattr(model, 'module'):
|
if hasattr(model, 'module'):
|
||||||
noise, x_noisy, noise_scale = model.module.compute_noisy_x(x)
|
noise, x_noisy, noise_scale = model.module.compute_y_n(x)
|
||||||
else:
|
else:
|
||||||
noise, x_noisy, noise_scale = model.compute_noisy_x(x)
|
noise, x_noisy, noise_scale = model.compute_y_n(x)
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
noise_hat = model(x_noisy, m, noise_scale)
|
noise_hat = model(x_noisy, m, noise_scale)
|
||||||
|
|
||||||
# compute losses
|
# compute losses
|
||||||
loss = criterion(noise, noise_hat)
|
loss = criterion(noise, noise_hat)
|
||||||
|
# if loss.item() > 100:
|
||||||
|
# breakpoint()
|
||||||
loss_wavegrad_dict = {'wavegrad_loss':loss}
|
loss_wavegrad_dict = {'wavegrad_loss':loss}
|
||||||
|
|
||||||
# backward pass with loss scaling
|
# backward pass with loss scaling
|
||||||
|
@ -212,8 +217,8 @@ def train(model, criterion, optimizer,
|
||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
|
tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
|
||||||
# TODO: plot model stats
|
# TODO: plot model stats
|
||||||
# if c.tb_model_param_stats:
|
if c.tb_model_param_stats:
|
||||||
# tb_logger.tb_model_weights(model, global_step)
|
tb_logger.tb_model_weights(model, global_step)
|
||||||
return keep_avg.avg_values, global_step
|
return keep_avg.avg_values, global_step
|
||||||
|
|
||||||
|
|
||||||
|
@ -236,9 +241,9 @@ def evaluate(model, criterion, ap, global_step, epoch):
|
||||||
|
|
||||||
# compute noisy input
|
# compute noisy input
|
||||||
if hasattr(model, 'module'):
|
if hasattr(model, 'module'):
|
||||||
noise, x_noisy, noise_scale = model.module.compute_noisy_x(x)
|
noise, x_noisy, noise_scale = model.module.compute_y_n(x)
|
||||||
else:
|
else:
|
||||||
noise, x_noisy, noise_scale = model.compute_noisy_x(x)
|
noise, x_noisy, noise_scale = model.compute_y_n(x)
|
||||||
|
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
|
@ -272,19 +277,20 @@ def evaluate(model, criterion, ap, global_step, epoch):
|
||||||
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
|
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
|
||||||
|
|
||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
|
data_loader.dataset.return_segments = False
|
||||||
samples = data_loader.dataset.load_test_samples(1)
|
samples = data_loader.dataset.load_test_samples(1)
|
||||||
m, x = format_test_data(samples[0])
|
m, x = format_test_data(samples[0])
|
||||||
|
|
||||||
# setup noise schedule and inference
|
# setup noise schedule and inference
|
||||||
noise_schedule = c['test_noise_schedule']
|
noise_schedule = c['test_noise_schedule']
|
||||||
if hasattr(model, 'module'):
|
if hasattr(model, 'module'):
|
||||||
model.module.init_noise_schedule(noise_schedule['num_steps'],
|
model.module.compute_noise_level(noise_schedule['num_steps'],
|
||||||
noise_schedule['min_val'],
|
noise_schedule['min_val'],
|
||||||
noise_schedule['max_val'])
|
noise_schedule['max_val'])
|
||||||
# compute voice
|
# compute voice
|
||||||
x_pred = model.module.inference(m)
|
x_pred = model.module.inference(m)
|
||||||
else:
|
else:
|
||||||
model.init_noise_schedule(noise_schedule['num_steps'],
|
model.compute_noise_level(noise_schedule['num_steps'],
|
||||||
noise_schedule['min_val'],
|
noise_schedule['min_val'],
|
||||||
noise_schedule['max_val'])
|
noise_schedule['max_val'])
|
||||||
# compute voice
|
# compute voice
|
||||||
|
@ -300,6 +306,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
|
||||||
c.audio["sample_rate"])
|
c.audio["sample_rate"])
|
||||||
|
|
||||||
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
|
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
|
||||||
|
data_loader.dataset.return_segments = True
|
||||||
|
|
||||||
return keep_avg.avg_values
|
return keep_avg.avg_values
|
||||||
|
|
||||||
|
@ -333,6 +340,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
# pylint: disable=import-outside-toplevel
|
# pylint: disable=import-outside-toplevel
|
||||||
from apex import amp
|
from apex import amp
|
||||||
model.cuda()
|
model.cuda()
|
||||||
|
# optimizer.cuda()
|
||||||
model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
|
model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
|
||||||
else:
|
else:
|
||||||
amp = None
|
amp = None
|
||||||
|
|
Loading…
Reference in New Issue