Add an optional prosody conditional flow module to VITS.

When `use_prosody_conditional_flow_module` is set, the emotion/prosody
embeddings are no longer handed to the text encoder (it is built with
emotion_emb_dim=0 / prosody_emb_dim=0 and called with emo_emb=None /
pros_emb=None). Instead `m_p` is passed through a 2-layer
ResidualCouplingBlocks stack conditioned on a single embedding `g`.

Review fix: `cond_channels` is now sized for the one embedding that both call
sites actually pass as `g` (the emotion embedding when emotion embeddings are
enabled, otherwise the prosody embedding). It was previously the sum of both
dimensions, which disagrees with `g` whenever emotion embeddings and the
prosody encoder are enabled together.

NOTE(review): indentation and blank context lines were reconstructed from the
hunk line counts; verify against the target tree before applying. The
post-image blob id on the vits.py `index` line predates the review fix.

diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 4d385a2a..da65c516 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -549,6 +549,8 @@ class VitsArgs(Coqpit):
     prosody_encoder_num_tokens: int = 5
     use_prosody_enc_spk_reversal_classifier: bool = False
 
+    use_prosody_conditional_flow_module: bool = False
+
     detach_dp_input: bool = True
     use_language_embedding: bool = False
     embedded_language_dim: int = 4
@@ -633,8 +635,8 @@ class Vits(BaseTTS):
             self.args.kernel_size_text_encoder,
             self.args.dropout_p_text_encoder,
             language_emb_dim=self.embedded_language_dim,
-            emotion_emb_dim=self.args.emotion_embedding_dim,
-            prosody_emb_dim=self.args.prosody_embedding_dim,
+            emotion_emb_dim=self.args.emotion_embedding_dim if not self.args.use_prosody_conditional_flow_module else 0,
+            prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_prosody_conditional_flow_module else 0,
         )
 
         self.posterior_encoder = PosteriorEncoder(
@@ -664,9 +666,16 @@ class Vits(BaseTTS):
         if self.args.use_prosody_encoder:
             dp_cond_embedding_dim += self.args.prosody_embedding_dim
 
+        dp_extra_inp_dim = 0
+        if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.use_prosody_conditional_flow_module:
+            dp_extra_inp_dim += self.args.emotion_embedding_dim
+
+        if self.args.use_prosody_encoder and not self.args.use_prosody_conditional_flow_module:
+            dp_extra_inp_dim += self.args.prosody_embedding_dim
+
         if self.args.use_sdp:
             self.duration_predictor = StochasticDurationPredictor(
-                self.args.hidden_channels + self.args.emotion_embedding_dim + self.args.prosody_embedding_dim,
+                self.args.hidden_channels + dp_extra_inp_dim,
                 192,
                 3,
                 self.args.dropout_p_duration_predictor,
@@ -676,7 +685,7 @@ class Vits(BaseTTS):
             )
         else:
             self.duration_predictor = DurationPredictor(
-                self.args.hidden_channels + self.args.emotion_embedding_dim + self.args.prosody_embedding_dim,
+                self.args.hidden_channels + dp_extra_inp_dim,
                 256,
                 3,
                 self.args.dropout_p_duration_predictor,
@@ -698,11 +707,27 @@ class Vits(BaseTTS):
                     hidden_channels=256,
                 )
 
+        if self.args.use_prosody_conditional_flow_module:
+            # only one embedding is passed as `g` at call time: emotion if enabled, otherwise prosody
+            cond_embedding_dim = 0
+            if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
+                cond_embedding_dim = self.args.emotion_embedding_dim
+            elif self.args.use_prosody_encoder:
+                cond_embedding_dim = self.args.prosody_embedding_dim
+
+            self.prosody_conditional_module = ResidualCouplingBlocks(
+                self.args.hidden_channels,
+                self.args.hidden_channels,
+                kernel_size=self.args.kernel_size_flow,
+                dilation_rate=self.args.dilation_rate_flow,
+                num_layers=2,
+                cond_channels=cond_embedding_dim,
+            )
+
         if self.args.use_text_enc_spk_reversal_classifier:
             self.speaker_text_enc_reversal_classifier = ReversalClassifier(
                 in_channels=self.args.hidden_channels
-                + self.args.emotion_embedding_dim
-                + self.args.prosody_embedding_dim,
+                + dp_extra_inp_dim,
                 out_channels=self.num_speakers,
                 hidden_channels=256,
             )
@@ -1097,7 +1122,17 @@ class Vits(BaseTTS):
             if self.args.use_prosody_enc_spk_reversal_classifier:
                 _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
 
-        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb)
+        x, m_p, logs_p, x_mask = self.text_encoder(
+            x,
+            x_lengths,
+            lang_emb=lang_emb,
+            emo_emb=eg if not self.args.use_prosody_conditional_flow_module else None,
+            pros_emb=pros_emb if not self.args.use_prosody_conditional_flow_module else None
+        )
+
+        # conditional module
+        if self.args.use_prosody_conditional_flow_module:
+            m_p = self.prosody_conditional_module(m_p, x_mask, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb)
 
         # reversal speaker loss to force the encoder to be speaker identity free
         l_text_speaker = None
@@ -1246,7 +1281,17 @@ class Vits(BaseTTS):
             z_pro, _, _, _ = self.posterior_encoder(pf, pf_lengths, g=g)
             pros_emb = self.prosody_encoder(z_pro).transpose(1, 2)
 
-        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb)
+        x, m_p, logs_p, x_mask = self.text_encoder(
+            x,
+            x_lengths,
+            lang_emb=lang_emb,
+            emo_emb=eg if not self.args.use_prosody_conditional_flow_module else None,
+            pros_emb=pros_emb if not self.args.use_prosody_conditional_flow_module else None
+        )
+
+        # conditional module
+        if self.args.use_prosody_conditional_flow_module:
+            m_p = self.prosody_conditional_module(m_p, x_mask, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb)
 
         # duration predictor
         g_dp = g if self.args.condition_dp_on_speaker else None
diff --git a/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py b/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py
index 4b67a339..c80e6771 100644
--- a/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py
+++ b/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py
@@ -47,6 +47,8 @@ config.model_args.emotion_embedding_dim = 256
 config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
 config.model_args.use_text_enc_spk_reversal_classifier = False
+
+config.model_args.use_prosody_conditional_flow_module = True
 
 # consistency loss
 # config.model_args.use_emotion_encoder_as_loss = True
 # config.model_args.encoder_model_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/model_se.pth.tar"