train.py update imports for utils refactoring

This commit is contained in:
erogol 2020-05-12 13:46:58 +02:00
parent 2d9dcd60ba
commit c0c3c6e331
1 changed files with 12 additions and 11 deletions

View File

@ -14,12 +14,13 @@ from distribute import (DistributedSampler, apply_gradient_allreduce,
init_distributed, reduce_tensor)
from TTS.layers.losses import TacotronLoss
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import ( from TTS.utils.generic_utils import (count_parameters, create_experiment_folder, remove_experiment_folder,
NoamLR, check_update, count_parameters, create_experiment_folder, get_git_branch, set_init_dict,
get_git_branch, load_config, remove_experiment_folder, save_best_model, setup_model, KeepAverage, check_config)
save_checkpoint, adam_weight_decay, set_init_dict, copy_config_file, from TTS.utils.io import (save_best_model, save_checkpoint,
setup_model, gradual_training_scheduler, KeepAverage, load_config, copy_config_file)
set_weight_decay, check_config) from TTS.utils.training import (NoamLR, check_update, adam_weight_decay,
gradual_training_scheduler, set_weight_decay)
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \
@ -251,9 +252,9 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
if global_step % c.save_step == 0:
    if c.checkpoint:
        # save model
save_checkpoint(model, optimizer, optimizer_st, save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH,
loss_dict['postnet_loss'].item(), OUT_PATH, global_step, optimizer_st=optimizer_st,
epoch) model_loss=loss_dict['postnet_loss'].item())
# Diagnostic visualizations
const_spec = postnet_output[0].data.cpu().numpy()
@ -596,8 +597,8 @@ def main(args): # pylint: disable=redefined-outer-name
target_loss = train_avg_loss_dict['avg_postnet_loss']
if c.run_eval:
    target_loss = eval_avg_loss_dict['avg_postnet_loss']
best_loss = save_best_model(model, optimizer, target_loss, best_loss, best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
OUT_PATH, global_step, epoch) OUT_PATH)
if __name__ == '__main__':