diff --git a/mozilla_voice_tts/bin/train_tts.py b/mozilla_voice_tts/bin/train_tts.py index 2b6cbfd0..1c27b5a4 100644 --- a/mozilla_voice_tts/bin/train_tts.py +++ b/mozilla_voice_tts/bin/train_tts.py @@ -441,7 +441,7 @@ def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=None): test_figures = {} print(" | > Synthesizing test sentences") speaker_id = 0 if c.use_speaker_embedding else None - style_wav = c.get("style_wav_for_test") + style_wav = c.get("gst_style_input") for idx, test_sentence in enumerate(test_sentences): try: wav, alignment, decoder_output, postnet_output, stop_tokens, _ = synthesis( @@ -624,7 +624,7 @@ def main(args): # pylint: disable=redefined-outer-name train_avg_loss_dict, global_step = train(model, criterion, optimizer, optimizer_st, scheduler, ap, global_step, epoch, amp, speaker_mapping) - eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch) + eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping) c_logger.print_epoch_end(epoch, eval_avg_loss_dict) target_loss = train_avg_loss_dict['avg_postnet_loss'] if c.run_eval: