diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 1646289b..f9cdd2e4 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -547,6 +547,7 @@ class VitsArgs(Coqpit): prosody_embedding_dim: int = 0 prosody_encoder_num_heads: int = 1 prosody_encoder_num_tokens: int = 5 + use_prosody_enc_spk_reversal_classifier: bool = False detach_dp_input: bool = True use_language_embedding: bool = False @@ -690,11 +691,12 @@ class Vits(BaseTTS): num_style_tokens=self.args.prosody_encoder_num_tokens, gst_embedding_dim=self.args.prosody_embedding_dim, ) - self.speaker_pros_enc_reversal_classifier = ReversalClassifier( - in_channels=self.args.prosody_embedding_dim, - out_channels=self.num_speakers, - hidden_channels=256, - ) + if self.args.use_prosody_enc_spk_reversal_classifier: + self.speaker_reversal_classifier = ReversalClassifier( + in_channels=self.args.prosody_embedding_dim, + out_channels=self.num_speakers, + hidden_channels=256, + ) if self.args.use_text_enc_spk_reversal_classifier: self.speaker_text_enc_reversal_classifier = ReversalClassifier( @@ -1089,7 +1091,8 @@ class Vits(BaseTTS): l_pros_speaker = None if self.args.use_prosody_encoder: pros_emb = self.prosody_encoder(z).transpose(1, 2) - _, l_pros_speaker = self.speaker_pros_enc_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None) + if self.args.use_prosody_enc_spk_reversal_classifier: + _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None) x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb)