From b4c2cf80f2a1fb3f0f0fee2c14e609edbb663dc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 24 Mar 2021 12:31:21 +0100 Subject: [PATCH] fix eval iter --- TTS/bin/train_align_tts.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/TTS/bin/train_align_tts.py b/TTS/bin/train_align_tts.py index 0260db76..1b3e7d52 100644 --- a/TTS/bin/train_align_tts.py +++ b/TTS/bin/train_align_tts.py @@ -309,19 +309,25 @@ if __name__ == '__main__': # forward pass model with torch.cuda.amp.autocast(enabled=c.mixed_precision): - decoder_output, dur_output, dur_mas_output, alignments, mu, log_sigma, logp_max_path = model.forward( + decoder_output, dur_output, dur_mas_output, alignments, _, _, logp = model.forward( text_input, text_lengths, mel_targets, mel_lengths, - g=speaker_c) + g=speaker_c, + phase=training_phase) + + # compute loss + loss_dict = criterion(logp, + decoder_output, + mel_targets, + mel_lengths, + dur_output, + dur_mas_output, + text_lengths, + global_step, + phase=training_phase) - # compute loss - loss_dict = criterion(mu, log_sigma, logp_max_path, - decoder_output, mel_targets, - mel_lengths, dur_output, - dur_mas_output, text_lengths, - global_step, training_phase) # step time step_time = time.time() - start_time