From c0c3c6e3311a0e23ef3eeb92afff920a3b7be45e Mon Sep 17 00:00:00 2001 From: erogol Date: Tue, 12 May 2020 13:46:58 +0200 Subject: [PATCH] train.py update imports for utils refactoring --- train.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/train.py b/train.py index 84648636..3eec0107 100644 --- a/train.py +++ b/train.py @@ -14,12 +14,13 @@ from distribute import (DistributedSampler, apply_gradient_allreduce, init_distributed, reduce_tensor) from TTS.layers.losses import TacotronLoss from TTS.utils.audio import AudioProcessor -from TTS.utils.generic_utils import ( - NoamLR, check_update, count_parameters, create_experiment_folder, - get_git_branch, load_config, remove_experiment_folder, save_best_model, - save_checkpoint, adam_weight_decay, set_init_dict, copy_config_file, - setup_model, gradual_training_scheduler, KeepAverage, - set_weight_decay, check_config) +from TTS.utils.generic_utils import (count_parameters, create_experiment_folder, remove_experiment_folder, + get_git_branch, set_init_dict, + setup_model, KeepAverage, check_config) +from TTS.utils.io import (save_best_model, save_checkpoint, + load_config, copy_config_file) +from TTS.utils.training import (NoamLR, check_update, adam_weight_decay, + gradual_training_scheduler, set_weight_decay) from TTS.utils.tensorboard_logger import TensorboardLogger from TTS.utils.console_logger import ConsoleLogger from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \ @@ -251,9 +252,9 @@ def train(model, criterion, optimizer, optimizer_st, scheduler, if global_step % c.save_step == 0: if c.checkpoint: # save model - save_checkpoint(model, optimizer, optimizer_st, - loss_dict['postnet_loss'].item(), OUT_PATH, global_step, - epoch) + save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH, + optimizer_st=optimizer_st, + model_loss=loss_dict['postnet_loss'].item()) # Diagnostic visualizations const_spec = postnet_output[0].data.cpu().numpy() @@ -596,8 +597,8 @@ def main(args): # pylint: disable=redefined-outer-name target_loss = train_avg_loss_dict['avg_postnet_loss'] if c.run_eval: target_loss = eval_avg_loss_dict['avg_postnet_loss'] - best_loss = save_best_model(model, optimizer, target_loss, best_loss, - OUT_PATH, global_step, epoch) + best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r, + OUT_PATH) if __name__ == '__main__':