From 14c2381207c5972359b2af450a233730ff877ee1 Mon Sep 17 00:00:00 2001
From: erogol
Date: Tue, 27 Oct 2020 12:06:57 +0100
Subject: [PATCH] weight norm and torch based amp training for wavegrad

---
 TTS/bin/train_wavegrad.py                  | 90 ++++++++++------------
 TTS/vocoder/configs/wavegrad_libritts.json |  4 +-
 TTS/vocoder/layers/wavegrad.py             | 86 +++++++++++++--------
 3 files changed, 97 insertions(+), 83 deletions(-)

diff --git a/TTS/bin/train_wavegrad.py b/TTS/bin/train_wavegrad.py
index db961047..83e5d78b 100644
--- a/TTS/bin/train_wavegrad.py
+++ b/TTS/bin/train_wavegrad.py
@@ -7,10 +7,7 @@ import traceback
 import torch
 
 # DISTRIBUTED
-try:
-    from apex.parallel import DistributedDataParallel as DDP_apex
-except:
-    from torch.nn.parallel import DistributedDataParallel as DDP_th
+from torch.nn.parallel import DistributedDataParallel as DDP_th
 from torch.optim import Adam
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
@@ -82,7 +79,7 @@ def format_test_data(data):
 
 
 def train(model, criterion, optimizer,
-          scheduler, ap, global_step, epoch, amp):
+          scheduler, ap, global_step, epoch):
     data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
     model.train()
     epoch_time = 0
@@ -104,6 +101,7 @@ def train(model, criterion, optimizer,
         model.compute_noise_level(noise_schedule['num_steps'],
                                   noise_schedule['min_val'],
                                   noise_schedule['max_val'])
+    scaler = torch.cuda.amp.GradScaler()
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
 
@@ -113,39 +111,46 @@ def train(model, criterion, optimizer,
 
         global_step += 1
 
-        # compute noisy input
-        if hasattr(model, 'module'):
-            noise, x_noisy, noise_scale = model.module.compute_y_n(x)
-        else:
-            noise, x_noisy, noise_scale = model.compute_y_n(x)
+        with torch.cuda.amp.autocast():
+            # compute noisy input
+            if hasattr(model, 'module'):
+                noise, x_noisy, noise_scale = model.module.compute_y_n(x)
+            else:
+                noise, x_noisy, noise_scale = model.compute_y_n(x)
 
-        # forward pass
-        noise_hat = model(x_noisy, m, noise_scale)
+            # forward pass
+            noise_hat = model(x_noisy, m, noise_scale)
 
-        # compute losses
-        loss = criterion(noise, noise_hat)
-        # if loss.item() > 100:
-        #     breakpoint()
+            # compute losses
+            loss = criterion(noise, noise_hat)
         loss_wavegrad_dict = {'wavegrad_loss':loss}
 
-        # backward pass with loss scaling
+        # check nan loss
+        if torch.isnan(loss).any():
+            raise RuntimeError(f'Detected NaN loss at step {global_step}.')
+
         optimizer.zero_grad()
-        if amp is not None:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
-                scaled_loss.backward()
-        else:
-            loss.backward()
-
-        if c.clip_grad > 0:
-            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
-                                                       c.clip_grad)
-        optimizer.step()
-
-        # schedule update
+        # schedule update
         if scheduler is not None:
             scheduler.step()
 
+        # backward pass with loss scaling
+        if c.mixed_precision:
+            scaler.scale(loss).backward()
+            scaler.unscale_(optimizer)
+            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
+                                                       c.clip_grad)
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            loss.backward()
+            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
+                                                       c.clip_grad)
+            optimizer.step()
+
+
         # disconnect loss values
         loss_dict = dict()
         for key, value in loss_wavegrad_dict.items():
@@ -175,7 +180,7 @@ def train(model, criterion, optimizer,
                 'step_time': [step_time, 2],
                 'loader_time': [loader_time, 4],
                 "current_lr": current_lr,
-                "grad_norm": grad_norm
+                "grad_norm": grad_norm.item()
             }
             c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict,
                                       loss_dict, keep_avg.avg_values)
@@ -185,7 +190,7 @@ def train(model, criterion, optimizer,
         if global_step % 10 == 0:
             iter_stats = {
                 "lr": current_lr,
-                "grad_norm": grad_norm,
+                "grad_norm": grad_norm.item(),
                 "step_time": step_time
             }
             iter_stats.update(loss_dict)
@@ -335,16 +340,6 @@ def main(args):  # pylint: disable=redefined-outer-name
     # setup optimizers
     optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)
 
-    # DISTRIBUTED
-    if c.apex_amp_level is not None:
-        # pylint: disable=import-outside-toplevel
-        from apex import amp
-        model.cuda()
-        # optimizer.cuda()
-        model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
-    else:
-        amp = None
-
     # schedulers
     scheduler = None
     if 'lr_scheduler' in c:
@@ -374,10 +369,6 @@ def main(args):  # pylint: disable=redefined-outer-name
             model.load_state_dict(model_dict)
             del model_dict
 
-        # DISTRUBUTED
-        if amp and 'amp' in checkpoint:
-            amp.load_state_dict(checkpoint['amp'])
-
         # reset lr if not countinuining training.
         for group in optimizer.param_groups:
             group['lr'] = c.lr
@@ -410,7 +401,7 @@ def main(args):  # pylint: disable=redefined-outer-name
         c_logger.print_epoch_start(epoch, c.epochs)
         _, global_step = train(model, criterion, optimizer,
                                scheduler, ap, global_step,
-                               epoch, amp)
+                               epoch)
         eval_avg_loss_dict = evaluate(model, criterion, ap,
                                       global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
@@ -426,8 +417,7 @@ def main(args):  # pylint: disable=redefined-outer-name
                                    global_step,
                                    epoch,
                                    OUT_PATH,
-                                   model_losses=eval_avg_loss_dict,
-                                   amp_state_dict=amp.state_dict() if amp else None)
+                                   model_losses=eval_avg_loss_dict)
 
 
 if __name__ == '__main__':
@@ -481,8 +471,8 @@ if __name__ == '__main__':
     _ = os.path.dirname(os.path.realpath(__file__))
 
     # DISTRIBUTED
-    if c.apex_amp_level is not None:
-        print(" > apex AMP level: ", c.apex_amp_level)
+    if c.mixed_precision:
+        print(" > Mixed precision is enabled")
 
     OUT_PATH = args.continue_path
     if args.continue_path == '':
diff --git a/TTS/vocoder/configs/wavegrad_libritts.json b/TTS/vocoder/configs/wavegrad_libritts.json
index 64958da2..5720a482 100644
--- a/TTS/vocoder/configs/wavegrad_libritts.json
+++ b/TTS/vocoder/configs/wavegrad_libritts.json
@@ -34,7 +34,7 @@
     },
 
     // DISTRIBUTED TRAINING
-    "apex_amp_level": "O1",  // APEX amp optimization level. "O1" is currently supported.
+    "mixed_precision": true,  // enable torch mixed precision training (true, false)
     "distributed":{
         "backend": "nccl",
         "url": "tcp:\/\/localhost:54322"
@@ -98,7 +98,7 @@
 
     // TENSORBOARD and LOGGING
     "print_step": 50,       // Number of steps to log traning on console.
    "print_eval": false,    // If True, it prints loss values for each step in eval run.
-    "save_step": 10000,     // Number of training steps expected to plot training stats on TB and save model checkpoints.
+    "save_step": 5000,      // Number of training steps expected to plot training stats on TB and save model checkpoints.
     "checkpoint": true,     // If true, it saves checkpoints per "save_step"
     "tb_model_param_stats": true,     // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
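
For reference, the native torch.cuda.amp pattern adopted above (in place of the apex-based path) boils down to the sketch below. The snippet is illustrative only and not part of the patch; `model`, `criterion`, `optimizer`, `loader`, and `clip_grad` are assumed stand-in names.

    import torch

    scaler = torch.cuda.amp.GradScaler()      # one scaler reused across steps
    for x, y in loader:
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():       # forward pass runs in mixed precision
            loss = criterion(model(x), y)
        scaler.scale(loss).backward()         # backward on the scaled loss
        scaler.unscale_(optimizer)            # unscale so clipping sees true gradient norms
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
        scaler.step(optimizer)                # skips the update if gradients overflowed
        scaler.update()                       # adjusts the loss scale for the next step
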
diff --git a/TTS/vocoder/layers/wavegrad.py b/TTS/vocoder/layers/wavegrad.py
index 0b9dde48..a72f2837 100644
--- a/TTS/vocoder/layers/wavegrad.py
+++ b/TTS/vocoder/layers/wavegrad.py
@@ -2,6 +2,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn.utils import weight_norm
 
 from math import log as ln
 
@@ -13,36 +14,59 @@ class Conv1d(nn.Conv1d):
         nn.init.zeros_(self.bias)
 
 
+# class PositionalEncoding(nn.Module):
+#     def __init__(self, n_channels):
+#         super().__init__()
+#         self.n_channels = n_channels
+#         self.length = n_channels // 2
+#         assert n_channels % 2 == 0
+
+#     def forward(self, x, noise_level):
+#         """
+#         Shapes:
+#             x: B x C x T
+#             noise_level: B
+#         """
+#         return (x + self.encoding(noise_level)[:, :, None])
+
+#     def encoding(self, noise_level):
+#         step = torch.arange(
+#             self.length, dtype=noise_level.dtype, device=noise_level.device) / self.length
+#         encoding = noise_level.unsqueeze(1) * torch.exp(
+#             -ln(1e4) * step.unsqueeze(0))
+#         encoding = torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1)
+#         return encoding
+
+
 class PositionalEncoding(nn.Module):
-    def __init__(self, n_channels):
+    def __init__(self, n_channels, max_len=10000):
         super().__init__()
         self.n_channels = n_channels
-        self.length = n_channels // 2
-        assert n_channels % 2 == 0
+        self.max_len = max_len
+        self.C = 5000
+        self.pe = torch.zeros(0, 0)
 
     def forward(self, x, noise_level):
-        """
-        Shapes:
-            x: B x C x T
-            noise_level: B
-        """
-        return (x + self.encoding(noise_level)[:, :, None])
+        if x.shape[2] > self.pe.shape[1]:
+            self.init_pe_matrix(x.shape[1] ,x.shape[2], x)
+        return x + noise_level[..., None, None] + self.pe[:, :x.size(2)].repeat(x.shape[0], 1, 1) / self.C
 
-    def encoding(self, noise_level):
-        step = torch.arange(
-            self.length, dtype=noise_level.dtype, device=noise_level.device) / self.length
-        encoding = noise_level.unsqueeze(1) * torch.exp(
-            -ln(1e4) * step.unsqueeze(0))
-        encoding = torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1)
-        return encoding
+    def init_pe_matrix(self, n_channels, max_len, x):
+        pe = torch.zeros(max_len, n_channels)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.pow(10000, torch.arange(0, n_channels, 2).float() / n_channels)
+
+        pe[:, 0::2] = torch.sin(position / div_term)
+        pe[:, 1::2] = torch.cos(position / div_term)
+        self.pe = pe.transpose(0, 1).to(x)
 
 
 class FiLM(nn.Module):
     def __init__(self, input_size, output_size):
         super().__init__()
         self.encoding = PositionalEncoding(input_size)
-        self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1)
-        self.output_conv = nn.Conv1d(input_size, output_size * 2, 3, padding=1)
+        self.input_conv = weight_norm(nn.Conv1d(input_size, input_size, 3, padding=1))
+        self.output_conv = weight_norm(nn.Conv1d(input_size, output_size * 2, 3, padding=1))
         self.ini_parameters()
 
     def ini_parameters(self):
@@ -72,30 +96,30 @@ class UBlock(nn.Module):
         assert len(dilation) == 4
         self.factor = factor
 
-        self.block1 = Conv1d(input_size, hidden_size, 1)
+        self.block1 = weight_norm(Conv1d(input_size, hidden_size, 1))
         self.block2 = nn.ModuleList([
-            Conv1d(input_size,
+            weight_norm(Conv1d(input_size,
                    hidden_size,
                    3,
                    dilation=dilation[0],
-                   padding=dilation[0]),
-            Conv1d(hidden_size,
+                   padding=dilation[0])),
+            weight_norm(Conv1d(hidden_size,
                    hidden_size,
                    3,
                    dilation=dilation[1],
-                   padding=dilation[1])
+                   padding=dilation[1]))
         ])
         self.block3 = nn.ModuleList([
-            Conv1d(hidden_size,
+            weight_norm(Conv1d(hidden_size,
                    hidden_size,
                    3,
                    dilation=dilation[2],
-                   padding=dilation[2]),
-            Conv1d(hidden_size,
+                   padding=dilation[2])),
+            weight_norm(Conv1d(hidden_size,
                    hidden_size,
                    3,
                    dilation=dilation[3],
-                   padding=dilation[3])
+                   padding=dilation[3]))
         ])
 
     def forward(self, x, shift, scale):
@@ -129,11 +153,11 @@ class DBlock(nn.Module):
     def __init__(self, input_size, hidden_size, factor):
         super().__init__()
         self.factor = factor
-        self.residual_dense = Conv1d(input_size, hidden_size, 1)
+        self.residual_dense = weight_norm(Conv1d(input_size, hidden_size, 1))
         self.conv = nn.ModuleList([
-            Conv1d(input_size, hidden_size, 3, dilation=1, padding=1),
-            Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2),
-            Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4),
+            weight_norm(Conv1d(input_size, hidden_size, 3, dilation=1, padding=1)),
+            weight_norm(Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2)),
+            weight_norm(Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4)),
         ])
 
     def forward(self, x):
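
A usage note on the weight_norm wrapping introduced in this file: for inference or export it is common practice (not something this patch does) to strip the reparameterization with torch.nn.utils.remove_weight_norm. The sketch below is illustrative; `conv` and its channel sizes are stand-ins.

    import torch.nn as nn
    from torch.nn.utils import weight_norm, remove_weight_norm

    conv = weight_norm(nn.Conv1d(80, 128, 3, padding=1))  # trains with weight_g / weight_v parameters
    # ... train ...
    remove_weight_norm(conv)   # folds the parametrization back into a single weight tensor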