From 5a59467f34ea1b33a10ea8cc205c6b2bbb1bc158 Mon Sep 17 00:00:00 2001
From: erogol
Date: Sat, 14 Nov 2020 13:00:35 +0100
Subject: [PATCH] scaler fix for wavegrad and wavernn. Save and load scaler

---
 TTS/bin/train_vocoder_wavegrad.py | 22 ++++++++++------------
 TTS/bin/train_vocoder_wavernn.py  | 15 ++++++++++-----
 2 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/TTS/bin/train_vocoder_wavegrad.py b/TTS/bin/train_vocoder_wavegrad.py
index e5caa9e2..261be3fa 100644
--- a/TTS/bin/train_vocoder_wavegrad.py
+++ b/TTS/bin/train_vocoder_wavegrad.py
@@ -4,6 +4,7 @@ import os
 import sys
 import time
 import traceback
 
+import numpy as np
 import torch
 # DISTRIBUTED
@@ -94,14 +95,11 @@ def train(model, criterion, optimizer,
     c_logger.print_train_start()
     # setup noise schedule
     noise_schedule = c['train_noise_schedule']
+    betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'], noise_schedule['num_steps'])
     if hasattr(model, 'module'):
-        model.module.compute_noise_level(noise_schedule['num_steps'],
-                                         noise_schedule['min_val'],
-                                         noise_schedule['max_val'])
+        model.module.compute_noise_level(betas)
     else:
-        model.compute_noise_level(noise_schedule['num_steps'],
-                                  noise_schedule['min_val'],
-                                  noise_schedule['max_val'])
+        model.compute_noise_level(betas)
 
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
@@ -287,16 +285,13 @@ def evaluate(model, criterion, ap, global_step, epoch):
 
         # setup noise schedule and inference
        noise_schedule = c['test_noise_schedule']
+        betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'], noise_schedule['num_steps'])
         if hasattr(model, 'module'):
-            model.module.compute_noise_level(noise_schedule['num_steps'],
-                                             noise_schedule['min_val'],
-                                             noise_schedule['max_val'])
+            model.module.compute_noise_level(betas)
             # compute voice
             x_pred = model.module.inference(m)
         else:
-            model.compute_noise_level(noise_schedule['num_steps'],
-                                      noise_schedule['min_val'],
-                                      noise_schedule['max_val'])
+            model.compute_noise_level(betas)
             # compute voice
             x_pred = model.inference(m)
 
@@ -363,6 +358,9 @@ def main(args): # pylint: disable=redefined-outer-name
             scheduler.load_state_dict(checkpoint['scheduler'])
             # NOTE: Not sure if necessary
             scheduler.optimizer = optimizer
+            if "scaler" in checkpoint and c.mixed_precision:
+                print(" > Restoring AMP Scaler...")
+                scaler.load_state_dict(checkpoint["scaler"])
         except RuntimeError:
             # restore only matching layers.
             print(" > Partial model initialization...")
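[Editor's note] The WaveGrad hunks above change compute_noise_level() to accept a precomputed beta array instead of the (num_steps, min_val, max_val) triple, so the caller owns the schedule and the model only derives noise levels from it. A rough, non-authoritative sketch of that derivation under the standard WaveGrad/diffusion formulation follows; the variable names and example schedule values are illustrative, not the repository's exact code:

    import numpy as np

    # Illustrative values; real runs read them from c['train_noise_schedule'].
    noise_schedule = {'min_val': 1e-6, 'max_val': 1e-2, 'num_steps': 1000}
    betas = np.linspace(noise_schedule['min_val'],
                        noise_schedule['max_val'],
                        noise_schedule['num_steps'])

    # Standard diffusion bookkeeping: per-step signal retention and its
    # cumulative product give the noise level used to corrupt the signal.
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas)        # \bar{alpha}_t
    noise_level = np.sqrt(alphas_cumprod)      # sqrt(cumprod(1 - beta_t))

Passing betas directly also lets training and evaluation use different schedules (train_noise_schedule vs. test_noise_schedule) without touching the model code.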
print(" > Partial model initialization...") diff --git a/TTS/bin/train_vocoder_wavernn.py b/TTS/bin/train_vocoder_wavernn.py index 8d563217..3f4c0fb3 100644 --- a/TTS/bin/train_vocoder_wavernn.py +++ b/TTS/bin/train_vocoder_wavernn.py @@ -81,7 +81,7 @@ def format_data(data): return x_input, mels, y_coarse -def train(model, optimizer, criterion, scheduler, ap, global_step, epoch): +def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch): # create train loader data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0)) model.train() @@ -94,7 +94,6 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch): batch_n_iter = int(len(data_loader.dataset) / c.batch_size) end_time = time.time() c_logger.print_train_start() - scaler = torch.cuda.amp.GradScaler() # train loop for num_iter, data in enumerate(data_loader): start_time = time.time() @@ -192,6 +191,7 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch): epoch, OUT_PATH, model_losses=loss_dict, + scaler=scaler.state_dict() if c.mixed_precision else None ) # synthesize a full voice @@ -352,6 +352,9 @@ def main(args): # pylint: disable=redefined-outer-name # setup model model_wavernn = setup_wavernn(c) + # setup amp scaler + scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None + # define train functions if c.mode == "mold": criterion = discretized_mix_logistic_loss @@ -387,8 +390,9 @@ def main(args): # pylint: disable=redefined-outer-name print(" > Restoring Generator LR Scheduler...") scheduler.load_state_dict(checkpoint["scheduler"]) scheduler.optimizer = optimizer - # TODO: fix resetting restored optimizer lr - # optimizer.load_state_dict(checkpoint["optimizer"]) + if "scaler" in checkpoint and c.mixed_precision: + print(" > Restoring AMP Scaler...") + scaler.load_state_dict(checkpoint["scaler"]) except RuntimeError: # retore only matching layers. print(" > Partial model initialization...") @@ -416,7 +420,7 @@ def main(args): # pylint: disable=redefined-outer-name for epoch in range(0, c.epochs): c_logger.print_epoch_start(epoch, c.epochs) _, global_step = train(model_wavernn, optimizer, - criterion, scheduler, ap, global_step, epoch) + criterion, scheduler, scaler, ap, global_step, epoch) eval_avg_loss_dict = evaluate( model_wavernn, criterion, ap, global_step, epoch) c_logger.print_epoch_end(epoch, eval_avg_loss_dict) @@ -434,6 +438,7 @@ def main(args): # pylint: disable=redefined-outer-name epoch, OUT_PATH, model_losses=eval_avg_loss_dict, + scaler=scaler.state_dict() if c.mixed_precision else None )