diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 17c2cb49..5713ccea 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -548,6 +548,7 @@ class VitsArgs(Coqpit):
     prosody_embedding_dim: int = 0
     prosody_encoder_num_heads: int = 1
     prosody_encoder_num_tokens: int = 5
+    use_prosody_encoder_z_p_input: bool = False
     use_prosody_enc_spk_reversal_classifier: bool = False
     use_prosody_enc_emo_classifier: bool = False
 
@@ -1132,17 +1133,24 @@ class Vits(BaseTTS):
         # posterior encoder
         z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
 
+        # flow layers
+        z_p = self.flow(z, y_mask, g=g)
+
         # prosody embedding
         pros_emb = None
         l_pros_speaker = None
         l_pros_emotion = None
         if self.args.use_prosody_encoder:
-            pros_emb = self.prosody_encoder(z).transpose(1, 2)
+            if not self.args.use_prosody_encoder_z_p_input:
+                pros_emb = self.prosody_encoder(z).transpose(1, 2)
+            else:
+                pros_emb = self.prosody_encoder(z_p).transpose(1, 2)
+
             if self.args.use_prosody_enc_spk_reversal_classifier:
                 _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
             if self.args.use_prosody_enc_emo_classifier:
                 _, l_pros_emotion = self.pros_enc_emotion_classifier(pros_emb.transpose(1, 2), eid, x_mask=None)
-
+
         x, m_p, logs_p, x_mask = self.text_encoder(
             x,
             x_lengths,
@@ -1156,9 +1164,6 @@ class Vits(BaseTTS):
         if self.args.use_text_enc_spk_reversal_classifier:
             _, l_text_speaker = self.speaker_text_enc_reversal_classifier(x.transpose(1, 2), sid, x_mask=None)
 
-        # flow layers
-        z_p = self.flow(z, y_mask, g=g)
-
         # reversal speaker loss to force the encoder to be speaker identity free
         l_text_emotion = None
         if self.args.use_text_enc_emo_classifier:
@@ -1315,8 +1320,12 @@ class Vits(BaseTTS):
         if self.args.use_prosody_encoder:
             # extract posterior encoder feature
             pf_lengths = torch.tensor([pf.size(-1)]).to(pf.device)
-            z_pro, _, _, _ = self.posterior_encoder(pf, pf_lengths, g=g)
-            pros_emb = self.prosody_encoder(z_pro).transpose(1, 2)
+            z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=g)
+            if not self.args.use_prosody_encoder_z_p_input:
+                pros_emb = self.prosody_encoder(z_pro).transpose(1, 2)
+            else:
+                z_p_inf = self.flow(z_pro, z_pro_y_mask, g=g)
+                pros_emb = self.prosody_encoder(z_p_inf).transpose(1, 2)
 
         x, m_p, logs_p, x_mask = self.text_encoder(
             x,
diff --git a/TTS/tts/utils/emotions.py b/TTS/tts/utils/emotions.py
index 9db1aaab..57cd8060 100644
--- a/TTS/tts/utils/emotions.py
+++ b/TTS/tts/utils/emotions.py
@@ -94,7 +94,7 @@ class EmotionManager(EmbeddingManager):
             EmotionEncoder: Emotion encoder object.
         """
         emotion_manager = None
-        if get_from_config_or_model_args_with_default(config, "use_emotion_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False):
+        if get_from_config_or_model_args_with_default(config, "use_emotion_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False) or get_from_config_or_model_args_with_default(config, "use_text_enc_emo_classifier", False):
             if get_from_config_or_model_args_with_default(config, "emotions_ids_file", None):
                 emotion_manager = EmotionManager(
                     emotion_id_file_path=get_from_config_or_model_args_with_default(config, "emotions_ids_file", None)
@@ -106,7 +106,7 @@ class EmotionManager(EmbeddingManager):
                 )
             )
 
-        if get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False):
+        if get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False) or get_from_config_or_model_args_with_default(config, "use_text_enc_emo_classifier", False):
             if get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None):
                 emotion_manager = EmotionManager(
                     embeddings_file_path=get_from_config_or_model_args_with_default(
diff --git a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
index e6f94059..d3b3051e 100644
--- a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
+++ b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
@@ -45,8 +45,9 @@
 config.model_args.use_prosody_encoder = True
 config.model_args.prosody_embedding_dim = 64
 # active classifier
 config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
-config.model_args.use_prosody_enc_emo_classifier = True
+config.model_args.use_prosody_enc_emo_classifier = False
 config.model_args.use_text_enc_emo_classifier = True
+config.model_args.use_prosody_encoder_z_p_input = True
 
 config.save_json(config_path)