train wavegrad updates

erogol 2020-10-26 16:46:26 +01:00
parent 670f44aa18
commit c8a4c771a8
1 changed file with 22 additions and 14 deletions


@@ -7,8 +7,10 @@ import traceback
 import torch
 # DISTRIBUTED
-from apex.parallel import DistributedDataParallel as DDP_apex
-from torch.nn.parallel import DistributedDataParallel as DDP_th
+try:
+    from apex.parallel import DistributedDataParallel as DDP_apex
+except:
+    from torch.nn.parallel import DistributedDataParallel as DDP_th
 from torch.optim import Adam
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
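
This hunk makes NVIDIA Apex an optional dependency: the native PyTorch wrapper is imported only when apex.parallel is unavailable. Note that each branch binds a different name (DDP_apex vs. DDP_th), so downstream code has to check which one exists. A minimal sketch of the same fallback written with a single name and an explicit flag (both names are illustrative, not the committed ones):

    # Optional-dependency fallback: prefer Apex DDP, else PyTorch's native DDP.
    # `DDP` and `has_apex` are illustrative names, not the committed ones.
    try:
        from apex.parallel import DistributedDataParallel as DDP
        has_apex = True
    except ImportError:  # the commit uses a bare `except:`; ImportError is the precise catch
        from torch.nn.parallel import DistributedDataParallel as DDP
        has_apex = False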
@@ -61,6 +63,7 @@ def setup_loader(ap, is_val=False, verbose=False):
 def format_data(data):
     # return a whole audio segment
     m, x = data
+    x = x.unsqueeze(1)
     if use_cuda:
         m = m.cuda(non_blocking=True)
         x = x.cuda(non_blocking=True)
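
The new unsqueeze(1) gives the raw-audio batch an explicit channel axis, which 1D convolutions expect. A quick shape check, assuming x arrives as [batch, samples]:

    import torch

    x = torch.randn(8, 6400)   # batch of audio segments: [batch, samples]
    x = x.unsqueeze(1)         # insert a channel axis at dim 1
    print(x.shape)             # torch.Size([8, 1, 6400]) -- ready for Conv1d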
@@ -70,8 +73,8 @@ def format_data(data):
 def format_test_data(data):
     # return a whole audio segment
     m, x = data
-    m = m.unsqueeze(0)
-    x = x.unsqueeze(0)
+    m = m[None, ...]
+    x = x[None, None, ...]
     if use_cuda:
         m = m.cuda(non_blocking=True)
         x = x.cuda(non_blocking=True)
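
Indexing with None is equivalent to unsqueeze: m[None, ...] adds one leading batch axis, while x[None, None, ...] adds both batch and channel axes in a single expression (the old code only added the batch axis to x). Shapes, assuming a single spectrogram and clip:

    import torch

    m = torch.randn(80, 100)   # one spectrogram: [mel_bins, frames]
    x = torch.randn(25600)     # one audio clip:  [samples]

    m = m[None, ...]           # [1, 80, 100]   == m.unsqueeze(0)
    x = x[None, None, ...]     # [1, 1, 25600]  == x.unsqueeze(0).unsqueeze(0)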
@@ -94,11 +97,11 @@ def train(model, criterion, optimizer,
     # setup noise schedule
     noise_schedule = c['train_noise_schedule']
     if hasattr(model, 'module'):
-        model.module.init_noise_schedule(noise_schedule['num_steps'],
+        model.module.compute_noise_level(noise_schedule['num_steps'],
                                          noise_schedule['min_val'],
                                          noise_schedule['max_val'])
     else:
-        model.init_noise_schedule(noise_schedule['num_steps'],
+        model.compute_noise_level(noise_schedule['num_steps'],
                                   noise_schedule['min_val'],
                                   noise_schedule['max_val'])
     for num_iter, data in enumerate(data_loader):
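
The rename from init_noise_schedule to compute_noise_level tracks the WaveGrad model API; the diff does not show the method body. As a rough sketch of what such a setup step typically computes — a linear beta schedule and its cumulative noise levels, per the WaveGrad paper — with the caveat that this standalone function and its return values are assumptions (the real method presumably stores the results on the model):

    import numpy as np

    def compute_noise_level(num_steps, min_val, max_val):
        """Hypothetical sketch: linear variance schedule and the
        cumulative noise levels used to scale clean audio."""
        beta = np.linspace(min_val, max_val, num_steps)  # variance per step
        alpha = 1.0 - beta
        alpha_hat = np.cumprod(alpha)                    # running product of (1 - beta)
        noise_level = np.sqrt(alpha_hat)                 # sqrt(alpha_hat), decays toward 0
        return beta, noise_level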
@@ -112,15 +115,17 @@ def train(model, criterion, optimizer,
         # compute noisy input
         if hasattr(model, 'module'):
-            noise, x_noisy, noise_scale = model.module.compute_noisy_x(x)
+            noise, x_noisy, noise_scale = model.module.compute_y_n(x)
         else:
-            noise, x_noisy, noise_scale = model.compute_noisy_x(x)
+            noise, x_noisy, noise_scale = model.compute_y_n(x)

         # forward pass
         noise_hat = model(x_noisy, m, noise_scale)

         # compute losses
         loss = criterion(noise, noise_hat)
+        # if loss.item() > 100:
+        #     breakpoint()
         loss_wavegrad_dict = {'wavegrad_loss':loss}

         # backward pass with loss scaling
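
compute_y_n is the forward diffusion step: it corrupts clean audio with Gaussian noise at a randomly sampled noise level and returns the noise (the regression target), the noisy waveform, and the conditioning scale; the model then predicts the noise and the criterion compares the two. A hedged sketch of that step, assuming a precomputed noise_level table as above (the sampling details of the real method may differ):

    import torch

    def compute_y_n(x, noise_level):
        """Hypothetical sketch of the noisy-input computation.
        x: clean audio [batch, 1, samples]; noise_level: 1-D tensor in (0, 1)."""
        b = x.shape[0]
        step = torch.randint(1, len(noise_level), (b,))
        # sample a continuous scale between two adjacent discrete levels
        low, high = noise_level[step], noise_level[step - 1]
        noise_scale = (low + torch.rand(b) * (high - low)).view(b, 1, 1)
        noise = torch.randn_like(x)
        y_n = noise_scale * x + (1.0 - noise_scale ** 2).sqrt() * noise
        return noise, y_n, noise_scale.view(b)  # scale conditions the model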
@@ -212,8 +217,8 @@ def train(model, criterion, optimizer,
     if args.rank == 0:
         tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
         # TODO: plot model stats
-        # if c.tb_model_param_stats:
-        #     tb_logger.tb_model_weights(model, global_step)
+        if c.tb_model_param_stats:
+            tb_logger.tb_model_weights(model, global_step)
     return keep_avg.avg_values, global_step
@@ -236,9 +241,9 @@ def evaluate(model, criterion, ap, global_step, epoch):
         # compute noisy input
         if hasattr(model, 'module'):
-            noise, x_noisy, noise_scale = model.module.compute_noisy_x(x)
+            noise, x_noisy, noise_scale = model.module.compute_y_n(x)
         else:
-            noise, x_noisy, noise_scale = model.compute_noisy_x(x)
+            noise, x_noisy, noise_scale = model.compute_y_n(x)

         # forward pass
@@ -272,19 +277,20 @@ def evaluate(model, criterion, ap, global_step, epoch):
         c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

     if args.rank == 0:
+        data_loader.dataset.return_segments = False
         samples = data_loader.dataset.load_test_samples(1)
         m, x = format_test_data(samples[0])

         # setup noise schedule and inference
         noise_schedule = c['test_noise_schedule']
         if hasattr(model, 'module'):
-            model.module.init_noise_schedule(noise_schedule['num_steps'],
+            model.module.compute_noise_level(noise_schedule['num_steps'],
                                              noise_schedule['min_val'],
                                              noise_schedule['max_val'])
             # compute voice
             x_pred = model.module.inference(m)
         else:
-            model.init_noise_schedule(noise_schedule['num_steps'],
+            model.compute_noise_level(noise_schedule['num_steps'],
                                       noise_schedule['min_val'],
                                       noise_schedule['max_val'])
             # compute voice
@@ -300,6 +306,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
                               c.audio["sample_rate"])
         tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
+    data_loader.dataset.return_segments = True
     return keep_avg.avg_values
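
During training the dataset yields fixed-length segments; flipping return_segments off before load_test_samples makes it yield full clips for test synthesis, and the flag is restored before returning so the next epoch trains on segments again. The same toggle written defensively (the try/finally is an illustration, not the committed code):

    # Illustrative: restore segment mode even if test synthesis raises.
    data_loader.dataset.return_segments = False
    try:
        samples = data_loader.dataset.load_test_samples(1)
        m, x = format_test_data(samples[0])
        x_pred = model.inference(m)
    finally:
        data_loader.dataset.return_segments = True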
@@ -333,6 +340,7 @@ def main(args):  # pylint: disable=redefined-outer-name
         # pylint: disable=import-outside-toplevel
         from apex import amp
         model.cuda()
+        # optimizer.cuda()
         model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
     else:
         amp = None
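
The commented-out optimizer.cuda() is correct to leave disabled: PyTorch optimizers have no .cuda() method; their state follows the parameters, which model.cuda() already moved. For reference, the standard Apex AMP pattern behind the "backward pass with loss scaling" comment in train() — this is the documented apex API in a self-contained toy, not this file's code (requires a GPU and Apex installed):

    import torch
    from apex import amp

    model = torch.nn.Linear(10, 10).cuda()      # toy stand-ins for the script's
    optimizer = torch.optim.Adam(model.parameters())
    criterion = torch.nn.L1Loss()

    # initialize once, then scale every loss before backward
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    x = torch.randn(4, 10).cuda()
    optimizer.zero_grad()
    loss = criterion(model(x), torch.zeros(4, 10).cuda())
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()                  # backward on the scaled loss
    optimizer.step()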