From 4aaf57b50b7d806dfddc22d49d8ffc0ff3d9f137 Mon Sep 17 00:00:00 2001 From: erogol Date: Thu, 4 Jun 2020 14:30:12 +0200 Subject: [PATCH] train.py update for DDC --- train.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index 5a345b59..612ee8a6 100644 --- a/train.py +++ b/train.py @@ -158,13 +158,14 @@ def train(model, criterion, optimizer, optimizer_st, scheduler, optimizer_st.zero_grad() # forward pass model - if c.bidirectional_decoder: + if c.bidirectional_decoder or c.double_decoder_consistency: decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model( - text_input, text_lengths, mel_input, speaker_ids=speaker_ids) + text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids) else: decoder_output, postnet_output, alignments, stop_tokens = model( - text_input, text_lengths, mel_input, speaker_ids=speaker_ids) + text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids) decoder_backward_output = None + alignments_backward = None # set the alignment lengths wrt reduction factor for guided attention if mel_lengths.max() % model.decoder.r != 0: @@ -176,7 +177,8 @@ def train(model, criterion, optimizer, optimizer_st, scheduler, loss_dict = criterion(postnet_output, decoder_output, mel_input, linear_input, stop_tokens, stop_targets, mel_lengths, decoder_backward_output, - alignments, alignment_lengths, text_lengths) + alignments, alignment_lengths, alignments_backward, + text_lengths) if c.bidirectional_decoder: keep_avg.update_values({'avg_decoder_b_loss': loss_dict['decoder_backward_loss'].item(), 'avg_decoder_c_loss': loss_dict['decoder_c_loss'].item()})