diff --git a/TTS/bin/train_align_tts.py b/TTS/bin/train_align_tts.py index 0260db76..1b3e7d52 100644 --- a/TTS/bin/train_align_tts.py +++ b/TTS/bin/train_align_tts.py @@ -309,19 +309,25 @@ if __name__ == '__main__': # forward pass model with torch.cuda.amp.autocast(enabled=c.mixed_precision): - decoder_output, dur_output, dur_mas_output, alignments, mu, log_sigma, logp_max_path = model.forward( + decoder_output, dur_output, dur_mas_output, alignments, _, _, logp = model.forward( text_input, text_lengths, mel_targets, mel_lengths, - g=speaker_c) + g=speaker_c, + phase=training_phase) + + # compute loss + loss_dict = criterion(logp, + decoder_output, + mel_targets, + mel_lengths, + dur_output, + dur_mas_output, + text_lengths, + global_step, + phase=training_phase) - # compute loss - loss_dict = criterion(mu, log_sigma, logp_max_path, - decoder_output, mel_targets, - mel_lengths, dur_output, - dur_mas_output, text_lengths, - global_step, training_phase) # step time step_time = time.time() - start_time