diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 0c94f91f..57d36717 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -587,13 +587,12 @@ class VitsGeneratorLoss(nn.Module): @staticmethod def cosine_similarity_loss(gt_spk_emb, syn_spk_emb): - l = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() - return l + return -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() def forward( self, - waveform, - waveform_hat, + mel_slice, + mel_slice_hat, z_p, logs_q, m_p, @@ -609,8 +608,8 @@ class VitsGeneratorLoss(nn.Module): ): """ Shapes: - - waveform : :math:`[B, 1, T]` - - waveform_hat: :math:`[B, 1, T]` + - mel_slice : :math:`[B, 1, T]` + - mel_slice_hat: :math:`[B, 1, T]` - z_p: :math:`[B, C, T]` - logs_q: :math:`[B, C, T]` - m_p: :math:`[B, C, T]` @@ -623,30 +622,23 @@ class VitsGeneratorLoss(nn.Module): loss = 0.0 return_dict = {} z_mask = sequence_mask(z_len).float() - # compute mel spectrograms from the waveforms - mel = self.stft(waveform) - mel_hat = self.stft(waveform_hat) - # compute losses - loss_kl = self.kl_loss( - z_p=z_p, - logs_q=logs_q, - m_p=m_p, - logs_p=logs_p, - z_mask=z_mask.unsqueeze(1)) * self.kl_loss_alpha - loss_feat = self.feature_loss( - feats_real=feats_disc_real, - feats_generated=feats_disc_fake) * self.feat_loss_alpha + loss_kl = ( + self.kl_loss(z_p=z_p, logs_q=logs_q, m_p=m_p, logs_p=logs_p, z_mask=z_mask.unsqueeze(1)) + * self.kl_loss_alpha + ) + loss_feat = ( + self.feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake) * self.feat_loss_alpha + ) loss_gen = self.generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha - loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha + loss_mel = torch.nn.functional.l1_loss(mel_slice, mel_slice_hat) * self.mel_loss_alpha loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration if use_speaker_encoder_as_loss: loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha - loss += loss_se + loss = loss + loss_se return_dict["loss_spk_encoder"] = loss_se - # pass losses to the dict return_dict["loss_gen"] = loss_gen return_dict["loss_kl"] = loss_kl @@ -675,16 +667,18 @@ class VitsDiscriminatorLoss(nn.Module): loss += real_loss + fake_loss real_losses.append(real_loss.item()) fake_losses.append(fake_loss.item()) - return loss, real_losses, fake_losses def forward(self, scores_disc_real, scores_disc_fake): loss = 0.0 return_dict = {} - loss_disc, _, _ = self.discriminator_loss(scores_real=scores_disc_real, scores_fake=scores_disc_fake) + loss_disc, loss_disc_real, _ = self.discriminator_loss(scores_real=scores_disc_real, scores_fake=scores_disc_fake) return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha loss = loss + return_dict["loss_disc"] return_dict["loss"] = loss + + for i, ldr in enumerate(loss_disc_real): + return_dict[f"loss_disc_real_{i}"] = ldr return return_dict