diff --git a/mozilla_voice_tts/bin/train_glow_tts.py b/mozilla_voice_tts/bin/train_glow_tts.py index c329e4f9..72b8a7f1 100644 --- a/mozilla_voice_tts/bin/train_glow_tts.py +++ b/mozilla_voice_tts/bin/train_glow_tts.py @@ -547,10 +547,9 @@ def main(args): # pylint: disable=redefined-outer-name model = data_depended_init(model, ap) for epoch in range(0, c.epochs): c_logger.print_epoch_start(epoch, c.epochs) - # train_avg_loss_dict, global_step = train(model, criterion, optimizer, - # scheduler, ap, global_step, - # epoch, amp) - train_avg_loss_dict, global_step = 0, 0 + train_avg_loss_dict, global_step = train(model, criterion, optimizer, + scheduler, ap, global_step, + epoch, amp) eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch) c_logger.print_epoch_end(epoch, eval_avg_loss_dict) target_loss = train_avg_loss_dict['avg_loss']