diff --git a/TTS/tts/layers/vits/prosody_encoder.py b/TTS/tts/layers/vits/prosody_encoder.py
index 7df2d9ff..bd7f3f33 100644
--- a/TTS/tts/layers/vits/prosody_encoder.py
+++ b/TTS/tts/layers/vits/prosody_encoder.py
@@ -7,6 +7,8 @@ class VitsGST(GST):
         super().__init__(*args, **kwargs)
 
     def forward(self, inputs, input_lengths=None, speaker_embedding=None):
+        if speaker_embedding is not None:
+            speaker_embedding = speaker_embedding.squeeze(-1)
         style_embed = super().forward(inputs, speaker_embedding=speaker_embedding)
         return style_embed, None
 
@@ -16,8 +18,10 @@ class VitsVAE(CapacitronVAE):
         super().__init__(*args, **kwargs)
         self.beta = None
 
-    def forward(self, inputs, input_lengths=None):
-        VAE_embedding, posterior_distribution, prior_distribution, _ = super().forward([inputs, input_lengths])
+    def forward(self, inputs, input_lengths=None, speaker_embedding=None):
+        if speaker_embedding is not None:
+            speaker_embedding = speaker_embedding.squeeze(-1)
+        VAE_embedding, posterior_distribution, prior_distribution, _ = super().forward([inputs, input_lengths], speaker_embedding=speaker_embedding)
         return VAE_embedding.to(inputs.device), [posterior_distribution, prior_distribution]
 
 
@@ -25,6 +29,6 @@ class ResNetProsodyEncoder(ResNetSpeakerEncoder):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    def forward(self, inputs, input_lengths=None):
+    def forward(self, inputs, input_lengths=None, speaker_embedding=None):
         style_embed = super().forward(inputs, l2_norm=True).unsqueeze(1)
         return style_embed, None
\ No newline at end of file
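
The squeeze(-1) guards added above exist because Vits passes its conditioning tensor `g` around as a 3-D tensor, while the underlying GST/Capacitron modules expect a flat 2-D speaker embedding. A minimal sketch of that shape handling, assuming the [B, C, 1] layout implied by the diff; B and C below are placeholder values, not taken from the code:

    import torch

    B, C = 2, 512                      # assumed batch size and conditioning width
    g = torch.randn(B, C, 1)           # Vits-style conditioning tensor, shape [B, C, 1]

    speaker_embedding = g.squeeze(-1)  # drop the trailing singleton dim -> [B, C]
    assert speaker_embedding.shape == (B, C)
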
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 0b09dad5..e63f57c4 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -555,8 +555,11 @@ class VitsArgs(Coqpit):
 
     # prosody encoder
     use_prosody_encoder: bool = False
+    use_pros_enc_input_as_pros_emb: bool = False
     prosody_encoder_type: str = "gst"
     detach_prosody_enc_input: bool = False
+    condition_pros_enc_on_speaker: bool = False
+
     prosody_embedding_dim: int = 0
     prosody_encoder_num_heads: int = 1
     prosody_encoder_num_tokens: int = 5
@@ -715,27 +718,34 @@ class Vits(BaseTTS):
             )
 
         if self.args.use_prosody_encoder:
-            if self.args.prosody_encoder_type == "gst":
-                self.prosody_encoder = VitsGST(
-                    num_mel=self.args.hidden_channels,
-                    num_heads=self.args.prosody_encoder_num_heads,
-                    num_style_tokens=self.args.prosody_encoder_num_tokens,
-                    gst_embedding_dim=self.args.prosody_embedding_dim,
-                )
-            elif self.args.prosody_encoder_type == "vae":
-                self.prosody_encoder = VitsVAE(
-                    num_mel=self.args.hidden_channels,
-                    capacitron_VAE_embedding_dim=self.args.prosody_embedding_dim,
-                )
-            elif self.args.prosody_encoder_type == "resnet":
-                self.prosody_encoder = ResNetProsodyEncoder(
-                    input_dim=self.args.hidden_channels,
-                    proj_dim=self.args.prosody_embedding_dim,
+            if self.args.use_pros_enc_input_as_pros_emb:
+                self.prosody_embedding_squeezer = nn.Linear(
+                    in_features=self.args.hidden_channels, out_features=self.args.prosody_embedding_dim
                 )
             else:
-                raise RuntimeError(
-                    f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!"
-                )
+                if self.args.prosody_encoder_type == "gst":
+                    self.prosody_encoder = VitsGST(
+                        num_mel=self.args.hidden_channels,
+                        num_heads=self.args.prosody_encoder_num_heads,
+                        num_style_tokens=self.args.prosody_encoder_num_tokens,
+                        gst_embedding_dim=self.args.prosody_embedding_dim,
+                        embedded_speaker_dim=self.cond_embedding_dim if self.args.condition_pros_enc_on_speaker else None,
+                    )
+                elif self.args.prosody_encoder_type == "vae":
+                    self.prosody_encoder = VitsVAE(
+                        num_mel=self.args.hidden_channels,
+                        capacitron_VAE_embedding_dim=self.args.prosody_embedding_dim,
+                        speaker_embedding_dim=self.cond_embedding_dim if self.args.condition_pros_enc_on_speaker else None,
+                    )
+                elif self.args.prosody_encoder_type == "resnet":
+                    self.prosody_encoder = ResNetProsodyEncoder(
+                        input_dim=self.args.hidden_channels,
+                        proj_dim=self.args.prosody_embedding_dim,
+                    )
+                else:
+                    raise RuntimeError(
+                        f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!"
+                    )
         if self.args.use_prosody_enc_spk_reversal_classifier:
             self.speaker_reversal_classifier = ReversalClassifier(
                 in_channels=self.args.prosody_embedding_dim,
@@ -1223,10 +1233,16 @@ class Vits(BaseTTS):
         l_pros_emotion = None
         if self.args.use_prosody_encoder:
             prosody_encoder_input = z_p if self.args.use_prosody_encoder_z_p_input else z
-            pros_emb, vae_outputs = self.prosody_encoder(
-                prosody_encoder_input.detach() if self.args.detach_prosody_enc_input else prosody_encoder_input,
-                y_lengths,
-            )
+            if not self.args.use_pros_enc_input_as_pros_emb:
+                pros_emb, vae_outputs = self.prosody_encoder(
+                    prosody_encoder_input.detach() if self.args.detach_prosody_enc_input else prosody_encoder_input,
+                    y_lengths,
+                    speaker_embedding=g if self.args.condition_pros_enc_on_speaker else None,
+                )
+            else:
+                pros_emb = prosody_encoder_input.mean(2).unsqueeze(1).detach()
+                pros_emb = F.normalize(self.prosody_embedding_squeezer(pros_emb.squeeze(1))).unsqueeze(1)
+
             pros_emb = pros_emb.transpose(1, 2)
 
             if self.args.use_prosody_enc_spk_reversal_classifier:
@@ -1427,13 +1443,19 @@ class Vits(BaseTTS):
             # extract posterior encoder feature
             pf_lengths = torch.tensor([pf.size(-1)]).to(pf.device)
             z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=ssg)
 
-            if not self.args.use_prosody_encoder_z_p_input:
-                pros_emb, _ = self.prosody_encoder(z_pro, pf_lengths)
+            if not self.args.use_pros_enc_input_as_pros_emb:
+                if not self.args.use_prosody_encoder_z_p_input:
+                    pros_emb, _ = self.prosody_encoder(z_pro, pf_lengths, speaker_embedding=ssg if self.args.condition_pros_enc_on_speaker else None)
+                else:
+                    z_p_inf = self.flow(z_pro, z_pro_y_mask, g=ssg)
+                    pros_emb, _ = self.prosody_encoder(z_p_inf, pf_lengths, speaker_embedding=ssg if self.args.condition_pros_enc_on_speaker else None)
             else:
-                z_p_inf = self.flow(z_pro, z_pro_y_mask, g=ssg)
-                pros_emb, _ = self.prosody_encoder(z_p_inf, pf_lengths)
+                prosody_encoder_input = self.flow(z_pro, z_pro_y_mask, g=ssg) if self.args.use_prosody_encoder_z_p_input else z_pro
+                pros_emb = prosody_encoder_input.mean(2).unsqueeze(1)
+                pros_emb = F.normalize(self.prosody_embedding_squeezer(pros_emb.squeeze(1))).unsqueeze(1)
             pros_emb = pros_emb.transpose(1, 2)
+
         x, m_p, logs_p, x_mask = self.text_encoder(
             x,
             x_lengths,
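
Both the training and inference branches above implement the same `use_pros_enc_input_as_pros_emb` shortcut: instead of running a learned prosody encoder, the posterior (or flow) latent is mean-pooled over time, projected by `prosody_embedding_squeezer`, and L2-normalized (the training path additionally detaches the latent). A standalone sketch of that computation; the sizes below are assumptions for illustration, not values from the diff:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    hidden_channels, prosody_embedding_dim = 192, 64      # assumed sizes
    squeezer = nn.Linear(hidden_channels, prosody_embedding_dim)

    z = torch.randn(2, hidden_channels, 37)               # latent, shape [B, C, T]

    pros_emb = z.mean(2).unsqueeze(1).detach()            # pool over time -> [B, 1, C], cut gradients
    pros_emb = F.normalize(squeezer(pros_emb.squeeze(1))).unsqueeze(1)  # project + L2-norm -> [B, 1, D]
    pros_emb = pros_emb.transpose(1, 2)                   # [B, D, 1], the layout Vits uses for conditioning
    assert pros_emb.shape == (2, prosody_embedding_dim, 1)
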
diff --git a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
index 0602f9a6..7dfa4f01 100644
--- a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
+++ b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
@@ -49,11 +49,17 @@
 config.model_args.use_prosody_enc_emo_classifier = False
 config.model_args.use_text_enc_emo_classifier = False
 config.model_args.use_prosody_encoder_z_p_input = True
-config.model_args.prosody_encoder_type = "resnet"
+config.model_args.prosody_encoder_type = "gst"
 config.model_args.detach_prosody_enc_input = True
+
 config.model_args.use_latent_discriminator = True
 config.model_args.use_noise_scale_predictor = False
 
+config.model_args.condition_pros_enc_on_speaker = True
+
+config.model_args.use_pros_enc_input_as_pros_emb = True
+config.model_args.use_prosody_embedding_squeezer = True
+config.model_args.prosody_embedding_squeezer_input_dim = 192
 
 # enable end2end loss
 config.model_args.use_end2end_loss = False
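
One interaction worth noting in this test configuration: with `use_pros_enc_input_as_pros_emb = True`, the constructor branch added in vits.py builds only the `prosody_embedding_squeezer`, so the GST encoder selected by `prosody_encoder_type = "gst"` is never instantiated and `condition_pros_enc_on_speaker` does not affect the code paths changed here. (`use_prosody_embedding_squeezer` and `prosody_embedding_squeezer_input_dim` are not among the VitsArgs fields added in this diff; they are presumably defined elsewhere on this branch.) A minimal sketch of the branch selection, mirroring the constructor logic above with plain flags standing in for `self.args`:

    # Flag values copied from the test config above
    use_pros_enc_input_as_pros_emb = True
    prosody_encoder_type = "gst"

    if use_pros_enc_input_as_pros_emb:
        branch = "nn.Linear squeezer over the mean-pooled latent"
    else:
        branch = f"dedicated prosody encoder: {prosody_encoder_type}"

    print(branch)  # -> nn.Linear squeezer over the mean-pooled latent
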