diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 79654c07..ec5f0811 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -200,7 +200,7 @@ class TacotronLoss(torch.nn.Module): self.criterion = nn.L1Loss() if c.model in ["Tacotron" ] else nn.MSELoss() # differential spectral loss - if c.diff_spec_loss_alpha > 0: + if c.diff_spec_alpha > 0: self.criterion_diff_spec = DifferentailSpectralLoss(loss_func=self.criterion) # guided attention loss if c.ga_alpha > 0: @@ -254,11 +254,17 @@ class TacotronLoss(torch.nn.Module): torch.flip(decoder_b_output, dims=(1, )), mel_input, output_lens) else: +<<<<<<< HEAD decoder_b_loss = self.criterion( torch.flip(decoder_b_output, dims=(1, )), mel_input) decoder_c_loss = torch.nn.functional.l1_loss( torch.flip(decoder_b_output, dims=(1, )), decoder_output) loss += decoder_b_loss + decoder_c_loss +======= + decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1, )), mel_input) + decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_b_output, dims=(1, )), decoder_output) + loss += self.decoder_alpha * (decoder_b_loss + decoder_c_loss) +>>>>>>> differential spectral loss return_dict['decoder_b_loss'] = decoder_b_loss return_dict['decoder_c_loss'] = decoder_c_loss @@ -267,9 +273,14 @@ class TacotronLoss(torch.nn.Module): decoder_b_loss = self.criterion(decoder_b_output, mel_input, output_lens) # decoder_c_loss = torch.nn.functional.l1_loss(decoder_b_output, decoder_output) +<<<<<<< HEAD attention_c_loss = torch.nn.functional.l1_loss( alignments, alignments_backwards) loss += decoder_b_loss + attention_c_loss +======= + attention_c_loss = torch.nn.functional.l1_loss(alignments, alignments_backwards) + loss += self.decoder_alpha * (decoder_b_loss + attention_c_loss) +>>>>>>> differential spectral loss return_dict['decoder_coarse_loss'] = decoder_b_loss return_dict['decoder_ddc_loss'] = attention_c_loss @@ -280,7 +291,7 @@ class TacotronLoss(torch.nn.Module): return_dict['ga_loss'] = ga_loss * self.ga_alpha # differential spectral loss - if self.config.diff_spec_loss_alpha > 0: + if self.config.diff_spec_alpha > 0: diff_spec_loss = self.criterion_diff_spec(postnet_output, mel_input, output_lens) loss += diff_spec_loss * self.diff_spec_alpha return_dict['diff_spec_loss'] = diff_spec_loss