mirror of https://github.com/coqui-ai/TTS.git
add DDC loss
This commit is contained in:
parent
cd93b7b351
commit
23f65df6c6
|
@ -184,7 +184,7 @@ class TacotronLoss(torch.nn.Module):
|
|||
|
||||
def forward(self, postnet_output, decoder_output, mel_input, linear_input,
|
||||
stopnet_output, stopnet_target, output_lens, decoder_b_output,
|
||||
alignments, alignment_lens, input_lens):
|
||||
alignments, alignment_lens, alignments_backwards, input_lens):
|
||||
|
||||
return_dict = {}
|
||||
# decoder and postnet losses
|
||||
|
@ -226,6 +226,15 @@ class TacotronLoss(torch.nn.Module):
|
|||
return_dict['decoder_b_loss'] = decoder_b_loss
|
||||
return_dict['decoder_c_loss'] = decoder_c_loss
|
||||
|
||||
# double decoder consistency loss (if enabled)
|
||||
if self.config.double_decoder_consistency:
|
||||
decoder_b_loss = self.criterion(decoder_b_output, mel_input, output_lens)
|
||||
# decoder_c_loss = torch.nn.functional.l1_loss(decoder_b_output, decoder_output)
|
||||
attention_c_loss = torch.nn.functional.l1_loss(alignments, alignments_backwards)
|
||||
loss += decoder_b_loss + attention_c_loss
|
||||
return_dict['decoder_coarse_loss'] = decoder_b_loss
|
||||
return_dict['decoder_ddc_loss'] = attention_c_loss
|
||||
|
||||
# guided attention loss (if enabled)
|
||||
if self.config.ga_alpha > 0:
|
||||
ga_loss = self.criterion_ga(alignments, input_lens, alignment_lens)
|
||||
|
|
Loading…
Reference in New Issue