mirror of https://github.com/coqui-ai/TTS.git
train.py update for DDC
This commit is contained in:
parent
fbf1689d18
commit
4aaf57b50b
10
train.py
10
train.py
|
@ -158,13 +158,14 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
|
|||
optimizer_st.zero_grad()
|
||||
|
||||
# forward pass model
|
||||
if c.bidirectional_decoder:
|
||||
if c.bidirectional_decoder or c.double_decoder_consistency:
|
||||
decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
|
||||
text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
|
||||
text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids)
|
||||
else:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model(
|
||||
text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
|
||||
text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids)
|
||||
decoder_backward_output = None
|
||||
alignments_backward = None
|
||||
|
||||
# set the alignment lengths wrt reduction factor for guided attention
|
||||
if mel_lengths.max() % model.decoder.r != 0:
|
||||
|
@ -176,7 +177,8 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
|
|||
loss_dict = criterion(postnet_output, decoder_output, mel_input,
|
||||
linear_input, stop_tokens, stop_targets,
|
||||
mel_lengths, decoder_backward_output,
|
||||
alignments, alignment_lengths, text_lengths)
|
||||
alignments, alignment_lengths, alignments_backward,
|
||||
text_lengths)
|
||||
if c.bidirectional_decoder:
|
||||
keep_avg.update_values({'avg_decoder_b_loss': loss_dict['decoder_backward_loss'].item(),
|
||||
'avg_decoder_c_loss': loss_dict['decoder_c_loss'].item()})
|
||||
|
|
Loading…
Reference in New Issue