From d6d8d0e3e1856886dc6d43d36a8a783dd9bf4aab Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 3 Jun 2022 13:05:40 +0000 Subject: [PATCH] Fix the VITS GAN loss --- TTS/tts/configs/vits_config.py | 6 ++--- TTS/tts/layers/losses.py | 27 +++++++++++++------ TTS/tts/models/vits.py | 1 + ...t_vits_speaker_emb_with_prosody_encoder.py | 6 ++--- 4 files changed, 26 insertions(+), 14 deletions(-) diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py index e848b827..981926e1 100644 --- a/TTS/tts/configs/vits_config.py +++ b/TTS/tts/configs/vits_config.py @@ -118,9 +118,9 @@ class VitsConfig(BaseTTSConfig): speaker_classifier_loss_alpha: float = 2.0 emotion_classifier_loss_alpha: float = 4.0 prosody_encoder_kl_loss_alpha: float = 5.0 - disc_latent_loss_alpha: float = 1.5 - gen_latent_loss_alpha: float = 1.5 - feat_latent_loss_alpha: float = 1.5 + disc_latent_loss_alpha: float = 5.0 + gen_latent_loss_alpha: float = 5.0 + feat_latent_loss_alpha: float = 108.0 # data loader params return_wav: bool = True diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 6b21ba72..118286a7 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -9,7 +9,13 @@ from torch.nn import functional from TTS.tts.utils.helpers import sequence_mask from TTS.tts.utils.ssim import ssim from TTS.utils.audio import TorchSTFT - +from TTS.vocoder.layers.losses import ( + MelganFeatureLoss, + MSEDLoss, + MSEGLoss, + _apply_D_loss, + _apply_G_adv_loss, +) # pylint: disable=abstract-method # relates https://github.com/pytorch/pytorch/issues/42305 @@ -607,6 +613,9 @@ class VitsGeneratorLoss(nn.Module): use_mel=True, do_amp_to_db=True, ) + if c.model_args.use_latent_discriminator: + self.latent_feat_match_loss = MelganFeatureLoss() + self.gen_latent_gan_loss = MSEGLoss() @staticmethod def feature_loss(feats_real, feats_generated): @@ -735,13 +744,11 @@ class VitsGeneratorLoss(nn.Module): if scores_disc_mp is not None and feats_disc_mp is not None and feats_disc_zp is not None: # feature loss - loss_feat_latent = ( - self.feature_loss(feats_real=feats_disc_zp, feats_generated=feats_disc_mp) * self.feat_latent_loss_alpha - ) + loss_feat_latent = self.latent_feat_match_loss(feats_disc_mp, feats_disc_zp) * self.feat_latent_loss_alpha return_dict["loss_feat_latent"] = loss_feat_latent loss += return_dict["loss_feat_latent"] # gen loss - loss_gen_latent = self.generator_loss(scores_fake=scores_disc_mp)[0] * self.gen_latent_loss_alpha + loss_gen_latent = _apply_G_adv_loss(scores_disc_mp, self.gen_latent_gan_loss) * self.gen_latent_loss_alpha return_dict["loss_gen_latent"] = loss_gen_latent loss += return_dict["loss_gen_latent"] @@ -801,7 +808,9 @@ class VitsDiscriminatorLoss(nn.Module): def __init__(self, c: Coqpit): super().__init__() self.disc_loss_alpha = c.disc_loss_alpha - self.disc_latent_loss_alpha = c.disc_latent_loss_alpha + if c.model_args.use_latent_discriminator: + self.disc_latent_loss_alpha = c.disc_latent_loss_alpha + self.disc_latent_gan_loss = MSEDLoss() @staticmethod @@ -833,9 +842,11 @@ class VitsDiscriminatorLoss(nn.Module): # latent discriminator if scores_disc_zp is not None and scores_disc_mp is not None: - loss_disc_latent, _, _ = self.discriminator_loss( - scores_real=scores_disc_zp, scores_fake=scores_disc_mp + loss_disc_latent, loss_disc_latent_zp, loss_disc_latent_mp = _apply_D_loss( + scores_fake=scores_disc_mp, scores_real=scores_disc_zp, loss_func=self.disc_latent_gan_loss ) + return_dict["loss_disc_latent_mp"] = loss_disc_latent_mp + return_dict["loss_disc_latent_zp"] = loss_disc_latent_zp return_dict["loss_disc_latent"] = loss_disc_latent * self.disc_latent_loss_alpha return_dict["loss"] += return_dict["loss_disc_latent"] diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 4e2cadfe..995467e5 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -36,6 +36,7 @@ from TTS.tts.utils.visual import plot_alignment from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results + ############################## # IO / Feature extraction ############################## diff --git a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py index 91ecfaeb..aa47a2aa 100644 --- a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py +++ b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py @@ -12,8 +12,8 @@ output_path = os.path.join(get_tests_output_path(), "train_outputs") config = VitsConfig( - batch_size=3, - eval_batch_size=3, + batch_size=2, + eval_batch_size=2, num_loader_workers=0, num_eval_loader_workers=0, text_cleaner="english_cleaners", @@ -52,7 +52,7 @@ config.model_args.use_prosody_encoder_z_p_input = True config.model_args.prosody_encoder_type = "vae" config.model_args.detach_prosody_enc_input = True -config.model_args.use_latent_discriminator = True +config.model_args.use_latent_discriminator = False # enable end2end loss config.model_args.use_end2end_loss = False