From a2aecea8f3a57006eade86bed4349a6a62f6205f Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Fri, 27 May 2022 12:53:56 -0300
Subject: [PATCH] Add VAE prosody encoder

---
 TTS/tts/configs/vits_config.py                 |  1 +
 TTS/tts/layers/losses.py                       | 13 ++++++-
 TTS/tts/layers/vits/prosody_encoder.py         | 19 +++++++++
 TTS/tts/models/vits.py                         | 39 ++++++++++++++-----
 ...t_vits_speaker_emb_with_prosody_encoder.py  |  2 +
 5 files changed, 63 insertions(+), 11 deletions(-)
 create mode 100644 TTS/tts/layers/vits/prosody_encoder.py

diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py
index b6892d85..946133fc 100644
--- a/TTS/tts/configs/vits_config.py
+++ b/TTS/tts/configs/vits_config.py
@@ -117,6 +117,7 @@ class VitsConfig(BaseTTSConfig):
     consistency_loss_alpha: float = 1.0
     speaker_classifier_loss_alpha: float = 2.0
     emotion_classifier_loss_alpha: float = 4.0
+    prosody_encoder_kl_loss_alpha: float = 5.0

     # data loader params
     return_wav: bool = True
diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index 5f44dc22..00b00b77 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -592,6 +592,7 @@ class VitsGeneratorLoss(nn.Module):
         self.consistency_loss_alpha = c.consistency_loss_alpha
         self.emotion_classifier_alpha = c.emotion_classifier_loss_alpha
         self.speaker_classifier_alpha = c.speaker_classifier_loss_alpha
+        self.prosody_encoder_kl_loss_alpha = c.prosody_encoder_kl_loss_alpha

         self.stft = TorchSTFT(
             c.audio.fft_size,
@@ -665,6 +666,7 @@ class VitsGeneratorLoss(nn.Module):
         use_encoder_consistency_loss=False,
         gt_cons_emb=None,
         syn_cons_emb=None,
+        vae_outputs=None,
         loss_prosody_enc_spk_rev_classifier=None,
         loss_prosody_enc_emo_classifier=None,
         loss_text_enc_spk_rev_classifier=None,
@@ -725,7 +727,16 @@
             loss += loss_text_enc_emo_classifier
             return_dict["loss_text_enc_emo_classifier"] = loss_text_enc_emo_classifier

-
+        if vae_outputs is not None:
+            posterior_distribution, prior_distribution = vae_outputs
+            # KL divergence term between the posterior and the prior
+            kl_term = torch.mean(torch.distributions.kl_divergence(posterior_distribution, prior_distribution))
+            # scale the KL term by its loss alpha
+            kl_vae_loss = kl_term * self.prosody_encoder_kl_loss_alpha
+
+            loss += kl_vae_loss
+            return_dict["loss_kl_vae"] = kl_vae_loss
+
         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
         return_dict["loss_kl"] = loss_kl
diff --git a/TTS/tts/layers/vits/prosody_encoder.py b/TTS/tts/layers/vits/prosody_encoder.py
new file mode 100644
index 00000000..ea8d11f6
--- /dev/null
+++ b/TTS/tts/layers/vits/prosody_encoder.py
@@ -0,0 +1,19 @@
+from TTS.tts.layers.tacotron.gst_layers import GST
+from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE
+
+class VitsGST(GST):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, inputs, input_lengths=None, speaker_embedding=None):
+        style_embed = super().forward(inputs, speaker_embedding=speaker_embedding)
+        return style_embed, None
+
+class VitsVAE(CapacitronVAE):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.beta = None
+
+    def forward(self, inputs, input_lengths=None):
+        VAE_embedding, posterior_distribution, prior_distribution, _ = super().forward([inputs, input_lengths])
+        return VAE_embedding.to(inputs.device), [posterior_distribution, prior_distribution]
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 5713ccea..0b668401 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -19,10 +19,11 @@ from TTS.tts.configs.shared_configs import CharactersConfig
 from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
 from TTS.tts.layers.generic.classifier import ReversalClassifier
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
-from TTS.tts.layers.tacotron.gst_layers import GST
+from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
+
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.emotions import EmotionManager
 from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
@@ -545,6 +546,7 @@ class VitsArgs(Coqpit):

     # prosody encoder
     use_prosody_encoder: bool = False
+    prosody_encoder_type: str = "gst"
     prosody_embedding_dim: int = 0
     prosody_encoder_num_heads: int = 1
     prosody_encoder_num_tokens: int = 5
@@ -698,11 +700,21 @@ class Vits(BaseTTS):
             )

         if self.args.use_prosody_encoder:
-            self.prosody_encoder = GST(
-                num_mel=self.args.hidden_channels,
-                num_heads=self.args.prosody_encoder_num_heads,
-                num_style_tokens=self.args.prosody_encoder_num_tokens,
-                gst_embedding_dim=self.args.prosody_embedding_dim,
+            if self.args.prosody_encoder_type == 'gst':
+                self.prosody_encoder = VitsGST(
+                    num_mel=self.args.hidden_channels,
+                    num_heads=self.args.prosody_encoder_num_heads,
+                    num_style_tokens=self.args.prosody_encoder_num_tokens,
+                    gst_embedding_dim=self.args.prosody_embedding_dim,
+                )
+            elif self.args.prosody_encoder_type == 'vae':
+                self.prosody_encoder = VitsVAE(
+                    num_mel=self.args.hidden_channels,
+                    capacitron_VAE_embedding_dim=self.args.prosody_embedding_dim,
+                )
+            else:
+                raise RuntimeError(
+                    f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!"
             )
         if self.args.use_prosody_enc_spk_reversal_classifier:
             self.speaker_reversal_classifier = ReversalClassifier(
@@ -1142,9 +1154,11 @@ class Vits(BaseTTS):
         l_pros_emotion = None
         if self.args.use_prosody_encoder:
             if not self.args.use_prosody_encoder_z_p_input:
-                pros_emb = self.prosody_encoder(z).transpose(1, 2)
+                pros_emb, vae_outputs = self.prosody_encoder(z, y_lengths)
             else:
-                pros_emb = self.prosody_encoder(z_p).transpose(1, 2)
+                pros_emb, vae_outputs = self.prosody_encoder(z_p, y_lengths)
+
+            pros_emb = pros_emb.transpose(1, 2)

             if self.args.use_prosody_enc_spk_reversal_classifier:
                 _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
@@ -1253,6 +1267,7 @@ class Vits(BaseTTS):
                 "gt_cons_emb": gt_cons_emb,
                 "syn_cons_emb": syn_cons_emb,
                 "slice_ids": slice_ids,
+                "vae_outputs": vae_outputs,
                 "loss_prosody_enc_spk_rev_classifier": l_pros_speaker,
                 "loss_prosody_enc_emo_classifier": l_pros_emotion,
                 "loss_text_enc_spk_rev_classifier": l_text_speaker,
@@ -1322,10 +1337,12 @@ class Vits(BaseTTS):
             pf_lengths = torch.tensor([pf.size(-1)]).to(pf.device)
             z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=g)
             if not self.args.use_prosody_encoder_z_p_input:
-                pros_emb = self.prosody_encoder(z_pro).transpose(1, 2)
+                pros_emb, vae_outputs = self.prosody_encoder(z_pro, pf_lengths)
             else:
                 z_p_inf = self.flow(z_pro, z_pro_y_mask, g=g)
-                pros_emb = self.prosody_encoder(z_p_inf).transpose(1, 2)
+                pros_emb, vae_outputs = self.prosody_encoder(z_p_inf, pf_lengths)
+
+            pros_emb = pros_emb.transpose(1, 2)

         x, m_p, logs_p, x_mask = self.text_encoder(
             x,
@@ -1469,6 +1486,7 @@ class Vits(BaseTTS):
         else:
             raise RuntimeError(" [!] Voice conversion is only supported on multi-speaker models.")
         # emotion embedding
+        ge_src, ge_tgt = None, None
         if self.args.use_emotion_embedding and ref_emotion is not None and target_emotion is not None:
             ge_src = self.emb_g(ref_emotion).unsqueeze(-1)
             ge_tgt = self.emb_g(target_emotion).unsqueeze(-1)
@@ -1602,6 +1620,7 @@ class Vits(BaseTTS):
                 or self.args.use_emotion_encoder_as_loss,
                 gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
                 syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
+                vae_outputs=self.model_outputs_cache["vae_outputs"],
                 loss_prosody_enc_spk_rev_classifier=self.model_outputs_cache["loss_prosody_enc_spk_rev_classifier"],
                 loss_prosody_enc_emo_classifier=self.model_outputs_cache["loss_prosody_enc_emo_classifier"],
                 loss_text_enc_spk_rev_classifier=self.model_outputs_cache["loss_text_enc_spk_rev_classifier"],
diff --git a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
index d3b3051e..5cc6344f 100644
--- a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
+++ b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
@@ -49,6 +49,8 @@
 config.model_args.use_prosody_enc_emo_classifier = False
 config.model_args.use_text_enc_emo_classifier = True
 config.model_args.use_prosody_encoder_z_p_input = True
+config.model_args.prosody_encoder_type = "vae"
+
 config.save_json(config_path)

 # train the model for one epoch
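
Note: the new loss term pairs the posterior returned by the prosody encoder with its prior via torch.distributions.kl_divergence and scales the batch-averaged result by prosody_encoder_kl_loss_alpha. A minimal, self-contained sketch of that computation follows; the Gaussians are illustrative stand-ins for the distributions the VAE encoder actually returns, and all shapes and values are made up.

import torch
from torch.distributions import MultivariateNormal, kl_divergence

batch_size, embedding_dim = 2, 16  # illustrative; the real dim is model_args.prosody_embedding_dim
posterior = MultivariateNormal(torch.randn(batch_size, embedding_dim), covariance_matrix=torch.eye(embedding_dim))
prior = MultivariateNormal(torch.zeros(batch_size, embedding_dim), covariance_matrix=torch.eye(embedding_dim))

prosody_encoder_kl_loss_alpha = 5.0                    # default added to VitsConfig in this patch
kl_term = torch.mean(kl_divergence(posterior, prior))  # scalar, averaged over the batch
kl_vae_loss = kl_term * prosody_encoder_kl_loss_alpha  # what ends up in return_dict["loss_kl_vae"]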
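
Note: both wrappers in prosody_encoder.py return an (embedding, vae_outputs) pair, which is what lets the call sites in Vits.forward() and Vits.inference() stay identical for "gst" and "vae" (vae_outputs is simply None on the GST path, so the KL term is skipped). The toy stand-in below follows the same contract; it only illustrates the interface and is not one of the patch's encoders.

import torch
from torch.distributions import Normal


class DummyProsodyEncoder(torch.nn.Module):
    """Hypothetical encoder mimicking the (embedding, [posterior, prior]) contract of VitsVAE."""

    def __init__(self, in_channels, embedding_dim):
        super().__init__()
        self.proj = torch.nn.Linear(in_channels, embedding_dim)

    def forward(self, inputs, input_lengths=None):
        # inputs: [batch, channels, frames], e.g. the posterior encoder output z
        pooled = inputs.mean(dim=2)                  # crude pooling over time, for illustration only
        mu = self.proj(pooled)
        posterior = Normal(mu, torch.ones_like(mu))
        prior = Normal(torch.zeros_like(mu), torch.ones_like(mu))
        return mu.unsqueeze(1), [posterior, prior]   # [batch, 1, dim] so the caller's transpose works


encoder = DummyProsodyEncoder(in_channels=192, embedding_dim=64)
z = torch.randn(2, 192, 120)                         # [batch, hidden_channels, frames]
pros_emb, vae_outputs = encoder(z, torch.tensor([120, 98]))
pros_emb = pros_emb.transpose(1, 2)                  # mirrors the updated call sites in vits.py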
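
Note: enabling the new encoder end to end only requires flipping the relevant model args, as the updated test above does. A hypothetical minimal snippet, assuming a checkout that already contains this patch; every value except prosody_encoder_type is illustrative.

from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import VitsArgs

config = VitsConfig(
    model_args=VitsArgs(
        use_prosody_encoder=True,
        prosody_encoder_type="vae",          # "gst" selects VitsGST, "vae" selects VitsVAE, anything else raises
        prosody_embedding_dim=64,            # illustrative value
        use_prosody_encoder_z_p_input=True,  # feed z_p instead of z to the prosody encoder
    ),
    prosody_encoder_kl_loss_alpha=5.0,       # weight of the new loss_kl_vae term
)
config.save_json("config.json")              # hypothetical output path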