From e07fcc7a8c7cc01ef18745d807ecd0ac849b1269 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Thu, 2 Jun 2022 18:40:57 -0300
Subject: [PATCH] Add text encoder adversarial loss to VITS

---
 TTS/tts/configs/vits_config.py                 |  3 ++
 TTS/tts/layers/losses.py                       | 34 +++++++++++--
 TTS/tts/layers/vits/discriminator.py           | 49 +++++++++++++++++--
 TTS/tts/models/vits.py                         | 22 ++++++---
 ...t_vits_speaker_emb_with_prosody_encoder.py  |  7 +--
 5 files changed, 98 insertions(+), 17 deletions(-)

diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py
index 946133fc..e848b827 100644
--- a/TTS/tts/configs/vits_config.py
+++ b/TTS/tts/configs/vits_config.py
@@ -118,6 +118,9 @@ class VitsConfig(BaseTTSConfig):
     speaker_classifier_loss_alpha: float = 2.0
     emotion_classifier_loss_alpha: float = 4.0
     prosody_encoder_kl_loss_alpha: float = 5.0
+    disc_latent_loss_alpha: float = 1.5
+    gen_latent_loss_alpha: float = 1.5
+    feat_latent_loss_alpha: float = 1.5
 
     # data loader params
     return_wav: bool = True
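
Note: the three new alphas weight the latent-discriminator terms the same way the existing *_loss_alpha fields weight the other VITS losses, so they can be tuned from the training config. A minimal sketch (the field names come from the hunk above; the values shown are just the defaults):

    from TTS.tts.configs.vits_config import VitsConfig

    config = VitsConfig(
        disc_latent_loss_alpha=1.5,  # scales the latent discriminator loss (discriminator step)
        gen_latent_loss_alpha=1.5,   # scales the adversarial latent loss (generator step)
        feat_latent_loss_alpha=1.5,  # scales the latent feature-matching loss (generator step)
    )
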
return_dict["loss_disc"] - return_dict["loss"] = loss + return_dict["loss"] += return_dict["loss_disc"] for i, ldr in enumerate(loss_disc_real): return_dict[f"loss_disc_real_{i}"] = ldr + # latent discriminator + if scores_disc_zp is not None and scores_disc_mp is not None: + loss_disc_latent, _, _ = self.discriminator_loss( + scores_real=scores_disc_zp, scores_fake=scores_disc_mp + ) + return_dict["loss_disc_latent"] = loss_disc_latent * self.disc_latent_loss_alpha + return_dict["loss"] += return_dict["loss_disc_latent"] + if end2end_info is not None: loss_disc_end2end, loss_disc_real_end2end, _ = self.discriminator_loss( scores_real=end2end_info["scores_disc_real"], scores_fake=end2end_info["scores_disc_fake"], diff --git a/TTS/tts/layers/vits/discriminator.py b/TTS/tts/layers/vits/discriminator.py index 148f283c..e65a3eba 100644 --- a/TTS/tts/layers/vits/discriminator.py +++ b/TTS/tts/layers/vits/discriminator.py @@ -58,13 +58,16 @@ class VitsDiscriminator(nn.Module): use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm. """ - def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False): + def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False, use_latent_disc=False): super().__init__() self.nets = nn.ModuleList() self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm)) self.nets.extend([DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]) + self.disc_latent = None + if use_latent_disc: + self.disc_latent = LatentDiscriminator(use_spectral_norm=use_spectral_norm) - def forward(self, x, x_hat=None): + def forward(self, x, x_hat=None, m_p=None, z_p=None): """ Args: x (Tensor): ground truth waveform. @@ -86,4 +89,44 @@ class VitsDiscriminator(nn.Module): x_hat_score, x_hat_feat = net(x_hat) x_hat_scores.append(x_hat_score) x_hat_feats.append(x_hat_feat) - return x_scores, x_feats, x_hat_scores, x_hat_feats + + # variables latent disc + mp_scores, zp_scores, mp_feats, zp_feats = None, None, None, None + if self.disc_latent is not None: + if m_p is not None: + mp_scores, mp_feats = self.disc_latent(m_p.unsqueeze(1)) + if z_p is not None: + zp_scores, zp_feats = self.disc_latent(z_p.unsqueeze(1)) + + return x_scores, x_feats, x_hat_scores, x_hat_feats, mp_scores, mp_feats, zp_scores, zp_feats + + +class LatentDiscriminator(nn.Module): + """Discriminator with the same architecture as the Univnet SpecDiscriminator""" + + def __init__(self, use_spectral_norm=False): + super().__init__() + norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm + self.discriminators = nn.ModuleList( + [ + norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))), + ] + ) + + self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1)) + + def forward(self, y): + fmap = [] + for _, d in enumerate(self.discriminators): + y = d(y) + y = torch.nn.functional.leaky_relu(y, 0.1) + fmap.append(y) + + y = self.out(y) + fmap.append(y) + + return torch.flatten(y, 1, -1), fmap diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 723ad2ab..4e2cadfe 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -20,7 +20,7 @@ from TTS.tts.datasets.dataset import TTSDataset, _parse_sample from 
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 723ad2ab..4e2cadfe 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -20,7 +20,7 @@
 from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
 from TTS.tts.layers.generic.classifier import ReversalClassifier
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE
-from TTS.tts.layers.vits.discriminator import VitsDiscriminator
+from TTS.tts.layers.vits.discriminator import VitsDiscriminator, LatentDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
@@ -563,6 +563,8 @@ class VitsArgs(Coqpit):
     use_end2end_loss: bool = False
     use_soft_dtw: bool = False
 
+    use_latent_discriminator: bool = False
+
     detach_dp_input: bool = True
     use_language_embedding: bool = False
     embedded_language_dim: int = 4
@@ -789,6 +791,7 @@ class Vits(BaseTTS):
         self.disc = VitsDiscriminator(
             periods=self.args.periods_multi_period_discriminator,
             use_spectral_norm=self.args.use_spectral_norm_disriminator,
+            use_latent_disc=self.args.use_latent_discriminator,
         )
 
     def init_multispeaker(self, config: Coqpit):
@@ -1631,13 +1634,13 @@ class Vits(BaseTTS):
             self.model_outputs_cache = outputs  # pylint: disable=attribute-defined-outside-init
 
             # compute scores and features
-            scores_disc_fake, _, scores_disc_real, _ = self.disc(
-                outputs["model_outputs"].detach(), outputs["waveform_seg"]
+            scores_disc_fake, _, scores_disc_real, _, scores_disc_mp, _, scores_disc_zp, _ = self.disc(
+                outputs["model_outputs"].detach(), outputs["waveform_seg"], outputs["m_p"].detach(), outputs["z_p"].detach()
             )
 
             end2end_info = None
             if self.args.use_end2end_loss:
-                scores_disc_fake_end2end, _, scores_disc_real_end2end, _ = self.disc(
+                scores_disc_fake_end2end, _, scores_disc_real_end2end, _, _, _, _, _ = self.disc(
                     outputs["end2end_info"]["model_outputs"].detach(), self.model_outputs_cache["end2end_info"]["waveform_seg"]
                 )
                 end2end_info = {"scores_disc_real": scores_disc_real_end2end, "scores_disc_fake": scores_disc_fake_end2end}
@@ -1647,6 +1650,8 @@ class Vits(BaseTTS):
             loss_dict = criterion[optimizer_idx](
                 scores_disc_real,
                 scores_disc_fake,
+                scores_disc_zp,
+                scores_disc_mp,
                 end2end_info=end2end_info,
             )
             return outputs, loss_dict
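
Note the asymmetry in what gets detached. In the discriminator pass above, both m_p and z_p are detached, so that step only updates the latent discriminator. In the generator pass below, only z_p is detached: the posterior acts as the "real" reference while gradients from the adversarial score on m_p flow back into the text encoder. Schematically, assuming the least-squares GAN objective the VITS waveform discriminators already use (a sketch, not the actual trainer code):

    import torch

    from TTS.tts.layers.vits.discriminator import LatentDiscriminator

    disc_latent = LatentDiscriminator()
    m_p = torch.randn(2, 192, 50, requires_grad=True)  # stands in for the text-encoder prior
    z_p = torch.randn(2, 192, 50)                      # stands in for the posterior latent

    # discriminator step: both latents detached, only D receives gradients
    d_zp, _ = disc_latent(z_p.detach().unsqueeze(1))  # treated as "real"
    d_mp, _ = disc_latent(m_p.detach().unsqueeze(1))  # treated as "fake"
    loss_disc_latent = torch.mean((1 - d_zp) ** 2) + torch.mean(d_mp ** 2)

    # generator step: m_p stays attached, so D's feedback trains the text encoder
    g_mp, _ = disc_latent(m_p.unsqueeze(1))
    loss_gen_latent = torch.mean((1 - g_mp) ** 2)
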
@@ -1678,12 +1683,12 @@ class Vits(BaseTTS):
             )
 
             # compute discriminator scores and features
-            scores_disc_fake, feats_disc_fake, _, feats_disc_real = self.disc(
-                self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"]
+            scores_disc_fake, feats_disc_fake, _, feats_disc_real, scores_disc_mp, feats_disc_mp, _, feats_disc_zp = self.disc(
+                self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"], self.model_outputs_cache["m_p"], self.model_outputs_cache["z_p"].detach()
             )
 
             if self.args.use_end2end_loss:
-                scores_disc_fake_end2end, feats_disc_fake_end2end, _, feats_disc_real_end2end = self.disc(
+                scores_disc_fake_end2end, feats_disc_fake_end2end, _, feats_disc_real_end2end, _, _, _, _ = self.disc(
                     self.model_outputs_cache["end2end_info"]["model_outputs"], self.model_outputs_cache["end2end_info"]["waveform_seg"]
                 )
                 self.model_outputs_cache["end2end_info"]["scores_disc_fake"] = scores_disc_fake_end2end
@@ -1713,6 +1718,9 @@ class Vits(BaseTTS):
                 loss_prosody_enc_emo_classifier=self.model_outputs_cache["loss_prosody_enc_emo_classifier"],
                 loss_text_enc_spk_rev_classifier=self.model_outputs_cache["loss_text_enc_spk_rev_classifier"],
                 loss_text_enc_emo_classifier=self.model_outputs_cache["loss_text_enc_emo_classifier"],
+                scores_disc_mp=scores_disc_mp,
+                feats_disc_mp=feats_disc_mp,
+                feats_disc_zp=feats_disc_zp,
                 end2end_info=self.model_outputs_cache["end2end_info"],
             )
diff --git a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
index 6fa4a536..91ecfaeb 100644
--- a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
+++ b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
@@ -12,8 +12,8 @@
 output_path = os.path.join(get_tests_output_path(), "train_outputs")
 
 config = VitsConfig(
-    batch_size=2,
-    eval_batch_size=2,
+    batch_size=3,
+    eval_batch_size=3,
     num_loader_workers=0,
     num_eval_loader_workers=0,
     text_cleaner="english_cleaners",
@@ -52,8 +52,9 @@
 config.model_args.use_prosody_encoder_z_p_input = True
 config.model_args.prosody_encoder_type = "vae"
 config.model_args.detach_prosody_enc_input = True
+config.model_args.use_latent_discriminator = True
 
-# enable end2end loss
-config.model_args.use_end2end_loss = True
+# disable end2end loss for this test
+config.model_args.use_end2end_loss = False
 
 config.mixed_precision = False
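
For completeness, the loss_feat_latent term wired into VitsGeneratorLoss above compares the intermediate activations (fmap) the latent discriminator produces for z_p and m_p, mirroring the waveform feature-matching loss. A minimal sketch of that computation, assuming an L1 feature-matching term in the spirit of VitsGeneratorLoss.feature_loss:

    import torch

    from TTS.tts.layers.vits.discriminator import LatentDiscriminator

    disc = LatentDiscriminator()
    m_p = torch.randn(2, 192, 50, requires_grad=True)
    z_p = torch.randn(2, 192, 50)

    _, feats_mp = disc(m_p.unsqueeze(1))  # features of the prior ("generated")
    _, feats_zp = disc(z_p.unsqueeze(1))  # features of the posterior ("real")

    # L1 distance between each pair of feature maps; the posterior side is
    # detached so it serves as a fixed target for the text encoder
    loss_feat_latent = sum(
        torch.mean(torch.abs(zp.detach() - mp)) for zp, mp in zip(feats_zp, feats_mp)
    )
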