From 23f65df6c6cbe9ee20547841c8d40e0a87b2c736 Mon Sep 17 00:00:00 2001 From: erogol Date: Thu, 4 Jun 2020 14:27:40 +0200 Subject: [PATCH] add DDC loss --- layers/losses.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/layers/losses.py b/layers/losses.py index 608e247d..f7745b6e 100644 --- a/layers/losses.py +++ b/layers/losses.py @@ -184,7 +184,7 @@ class TacotronLoss(torch.nn.Module): def forward(self, postnet_output, decoder_output, mel_input, linear_input, stopnet_output, stopnet_target, output_lens, decoder_b_output, - alignments, alignment_lens, input_lens): + alignments, alignment_lens, alignments_backwards, input_lens): return_dict = {} # decoder and postnet losses @@ -226,6 +226,15 @@ class TacotronLoss(torch.nn.Module): return_dict['decoder_b_loss'] = decoder_b_loss return_dict['decoder_c_loss'] = decoder_c_loss + # double decoder consistency loss (if enabled) + if self.config.double_decoder_consistency: + decoder_b_loss = self.criterion(decoder_b_output, mel_input, output_lens) + # decoder_c_loss = torch.nn.functional.l1_loss(decoder_b_output, decoder_output) + attention_c_loss = torch.nn.functional.l1_loss(alignments, alignments_backwards) + loss += decoder_b_loss + attention_c_loss + return_dict['decoder_coarse_loss'] = decoder_b_loss + return_dict['decoder_ddc_loss'] = attention_c_loss + # guided attention loss (if enabled) if self.config.ga_alpha > 0: ga_loss = self.criterion_ga(alignments, input_lens, alignment_lens)