diff --git a/train.py b/train.py
index 612ee8a6..13bda5ef 100644
--- a/train.py
+++ b/train.py
@@ -331,13 +331,14 @@ def evaluate(model, criterion, ap, global_step, epoch):
             assert mel_input.shape[1] % model.decoder.r == 0

             # forward pass model
-            if c.bidirectional_decoder:
+            if c.bidirectional_decoder or c.double_decoder_consistency:
                 decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
                     text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
             else:
                 decoder_output, postnet_output, alignments, stop_tokens = model(
                     text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
                 decoder_backward_output = None
+                alignments_backward = None

             # set the alignment lengths wrt reduction factor for guided attention
             if mel_lengths.max() % model.decoder.r != 0:
@@ -349,7 +350,8 @@ def evaluate(model, criterion, ap, global_step, epoch):
             loss_dict = criterion(postnet_output, decoder_output, mel_input,
                                   linear_input, stop_tokens, stop_targets,
                                   mel_lengths, decoder_backward_output,
-                                  alignments, alignment_lengths, text_lengths)
+                                  alignments, alignment_lengths, alignments_backward,
+                                  text_lengths)
             if c.bidirectional_decoder:
                 keep_avg.update_values({'avg_decoder_b_loss': loss_dict['decoder_b_loss'].item(),
                                         'avg_decoder_c_loss': loss_dict['decoder_c_loss'].item()})