bug fix for DDC model eval run

This commit is contained in:
erogol 2020-06-05 14:47:25 +02:00
parent d00b91710a
commit 79db7d931b
1 changed files with 4 additions and 2 deletions

View File

@ -331,13 +331,14 @@ def evaluate(model, criterion, ap, global_step, epoch):
assert mel_input.shape[1] % model.decoder.r == 0
# forward pass model
if c.bidirectional_decoder:
if c.bidirectional_decoder or c.double_decoder_consistency:
decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
else:
decoder_output, postnet_output, alignments, stop_tokens = model(
text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
decoder_backward_output = None
alignments_backward = None
# set the alignment lengths wrt reduction factor for guided attention
if mel_lengths.max() % model.decoder.r != 0:
@ -349,7 +350,8 @@ def evaluate(model, criterion, ap, global_step, epoch):
loss_dict = criterion(postnet_output, decoder_output, mel_input,
linear_input, stop_tokens, stop_targets,
mel_lengths, decoder_backward_output,
alignments, alignment_lengths, text_lengths)
alignments, alignment_lengths, alignments_backward,
text_lengths)
if c.bidirectional_decoder:
keep_avg.update_values({'avg_decoder_b_loss': loss_dict['decoder_b_loss'].item(),
'avg_decoder_c_loss': loss_dict['decoder_c_loss'].item()})