diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 0b668401..7ea87ebc 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -547,12 +547,14 @@ class VitsArgs(Coqpit): # prosody encoder use_prosody_encoder: bool = False prosody_encoder_type: str = "gst" + detach_prosody_enc_input: bool = False prosody_embedding_dim: int = 0 prosody_encoder_num_heads: int = 1 prosody_encoder_num_tokens: int = 5 use_prosody_encoder_z_p_input: bool = False use_prosody_enc_spk_reversal_classifier: bool = False use_prosody_enc_emo_classifier: bool = False + use_prosody_conditional_flow_module: bool = False prosody_conditional_flow_module_on_decoder: bool = False @@ -1153,10 +1155,11 @@ class Vits(BaseTTS): l_pros_speaker = None l_pros_emotion = None if self.args.use_prosody_encoder: - if not self.args.use_prosody_encoder_z_p_input: - pros_emb, vae_outputs = self.prosody_encoder(z, y_lengths) - else: - pros_emb, vae_outputs = self.prosody_encoder(z_p, y_lengths) + prosody_encoder_input = z_p if self.args.use_prosody_encoder_z_p_input else z + pros_emb, vae_outputs = self.prosody_encoder( + prosody_encoder_input.detach() if self.args.detach_prosody_enc_input else prosody_encoder_input, + y_lengths + ) pros_emb = pros_emb.transpose(1, 2) @@ -1337,10 +1340,10 @@ class Vits(BaseTTS): pf_lengths = torch.tensor([pf.size(-1)]).to(pf.device) z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=g) if not self.args.use_prosody_encoder_z_p_input: - pros_emb, vae_outputs = self.prosody_encoder(z_pro, pf_lengths) + pros_emb, _ = self.prosody_encoder(z_pro, pf_lengths) else: z_p_inf = self.flow(z_pro, z_pro_y_mask, g=g) - pros_emb, vae_outputs = self.prosody_encoder(z_p_inf, pf_lengths) + pros_emb, _ = self.prosody_encoder(z_p_inf, pf_lengths) pros_emb = pros_emb.transpose(1, 2) diff --git a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py index 5cc6344f..e7cad601 100644 --- a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py +++ b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py @@ -50,6 +50,7 @@ config.model_args.use_text_enc_emo_classifier = True config.model_args.use_prosody_encoder_z_p_input = True config.model_args.prosody_encoder_type = "vae" +config.model_args.detach_prosody_enc_input = True config.save_json(config_path)