diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py
index a32cd428..c71a4189 100644
--- a/TTS/tts/configs/vits_config.py
+++ b/TTS/tts/configs/vits_config.py
@@ -118,8 +118,8 @@ class VitsConfig(BaseTTSConfig):
     speaker_classifier_loss_alpha: float = 2.0
     emotion_classifier_loss_alpha: float = 4.0
     prosody_encoder_kl_loss_alpha: float = 5.0
-    disc_latent_loss_alpha: float = 5.0
-    gen_latent_loss_alpha: float = 5.0
+    disc_latent_loss_alpha: float = 1.0
+    gen_latent_loss_alpha: float = 1.0
     feat_latent_loss_alpha: float = 108.0
     pitch_loss_alpha: float = 5.0
     z_decoder_loss_alpha: float = 45.0
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index c985ba82..e0010250 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -788,7 +788,7 @@ class Vits(BaseTTS):
             self.args.dropout_p_text_encoder,
             language_emb_dim=self.embedded_language_dim,
             emo_emb_dim=self.args.emotion_embedding_dim if not self.args.use_noise_scale_predictor else 0,
-            prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module else 0,
+            prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else 0,
             pitch_dim=self.args.pitch_embedding_dim if self.args.use_pitch and self.args.use_pitch_on_enc_input else 0,
         )

@@ -827,7 +827,7 @@ class Vits(BaseTTS):
         ) and not self.args.use_noise_scale_predictor:
             dp_extra_inp_dim += self.args.emotion_embedding_dim

-        if self.args.use_prosody_encoder and not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module:
+        if self.args.use_prosody_encoder and not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder:
             dp_extra_inp_dim += self.args.prosody_embedding_dim

         if self.args.use_pitch and self.args.use_pitch_on_enc_input:
@@ -896,7 +896,7 @@ class Vits(BaseTTS):
                 self.args.pitch_predictor_hidden_channels,
                 self.args.pitch_predictor_kernel_size,
                 self.args.pitch_predictor_dropout_p,
-                cond_channels=dp_cond_embedding_dim,
+                cond_channels=self.cond_embedding_dim,
                 language_emb_dim=self.embedded_language_dim,
             )
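# Note on the cond_channels change above: dp_cond_embedding_dim sizes the duration
# predictor's conditioning and is 0 when condition_dp_on_speaker is disabled, while the
# pitch predictor is now called with the speaker embedding g (see the forward() hunks
# below), so its conditioning layer must be sized by self.cond_embedding_dim. A minimal
# sketch of the mismatch; the dims and the Conv1d stand-in are illustrative assumptions,
# not this repo's exact modules:
import torch

cond_embedding_dim = 512   # size of the speaker embedding g passed at call time
dp_cond_embedding_dim = 0  # 0 when condition_dp_on_speaker=False

g = torch.randn(2, cond_embedding_dim, 1)  # [B, C, 1]
cond_layer = torch.nn.Conv1d(cond_embedding_dim, 192, 1)  # sized from self.cond_embedding_dim
print(cond_layer(g).shape)  # torch.Size([2, 192, 1])
# a Conv1d built with in_channels=dp_cond_embedding_dim would fail with a channel-mismatch
# RuntimeError the moment g is passed in.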
@@ -1477,7 +1477,7 @@ class Vits(BaseTTS):
         if self.args.use_pitch and self.args.use_pitch_on_enc_input:
             if alignments is None:
                 raise RuntimeError(" [!] To condition the pitch on the Text Encoder you need to provide external alignments!")
-            pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(x, x_lengths, pitch, alignments.sum(3), g_dp)
+            pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(x, x_lengths, pitch, alignments.sum(3), g)

         # posterior encoder
         z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
@@ -1514,7 +1514,7 @@
             x_lengths,
             lang_emb=lang_emb,
             emo_emb=eg if not self.args.use_noise_scale_predictor else None,
-            pros_emb=pros_emb if not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module else None,
+            pros_emb=pros_emb if not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else None,
             pitch_emb=gt_avg_pitch_emb if self.args.use_pitch and self.args.use_pitch_on_enc_input else None,
         )

@@ -1536,10 +1536,20 @@

         outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)

+        z_p_avg = None
+        if self.args.use_latent_discriminator:
+            # average the z_p for the latent discriminator
+            z_p_avg = average_over_durations(z_p, attn.sum(3).squeeze())
+
         conditional_module_loss = None
+        new_m_p = None
         if self.args.use_encoder_conditional_module:
             g_cond = None
             cond_module_input = x
+            if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
+                pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(cond_module_input, x_lengths, pitch, attn.sum(3), g)
+                cond_module_input = cond_module_input + gt_avg_pitch_emb
+
             if self.args.use_prosody_encoder:
                 if g_cond is None:
                     g_cond = pros_emb
@@ -1549,18 +1559,17 @@
             if g_cond is not None:
                 cond_module_input = torch.cat((cond_module_input, g_cond.expand(-1, -1, cond_module_input.size(2))), dim=1)

-            new_m_p = self.encoder_conditional_module(cond_module_input, x_mask)
-            z_p_avg = average_over_durations(z_p, attn.sum(3).squeeze()).detach()
-            conditional_module_loss = torch.nn.functional.l1_loss(new_m_p * x_mask, z_p_avg)
-
-        if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
-            pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(m_p, x_lengths, pitch, attn.sum(3), g_dp)
-            m_p = m_p + gt_avg_pitch_emb
+            new_m_p = self.encoder_conditional_module(cond_module_input, x_mask) * x_mask
+            if z_p_avg is None:
+                z_p_avg = average_over_durations(z_p, attn.sum(3).squeeze()).detach()
+            else:
+                z_p_avg = z_p_avg.detach()

-        z_p_avg = None
-        if self.args.use_latent_discriminator:
-            # average the z_p for the latent discriminator
-            z_p_avg = average_over_durations(z_p, attn.sum(3).squeeze())
+            conditional_module_loss = torch.nn.functional.l1_loss(new_m_p, z_p_avg)
+
+        if self.args.use_pitch and not self.args.use_pitch_on_enc_input and not self.args.use_z_decoder and not self.args.use_encoder_conditional_module:
+            pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(m_p, x_lengths, pitch, attn.sum(3), g)
+            m_p = m_p + gt_avg_pitch_emb

         # expand prior
         m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
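# Note on z_p_avg above: average_over_durations collapses the frame-level latent z_p
# [B, C, T_dec] to one vector per input token using the per-token durations attn.sum(3),
# so the encoder conditional module can be regressed onto it with an L1 loss. A
# self-contained stand-in with the same semantics (shapes follow the diff; the repo's own
# average_over_durations helper is what vits.py actually calls):
import torch

def avg_over_durations(values: torch.Tensor, durs: torch.Tensor) -> torch.Tensor:
    """values: [B, C, T_dec], durs: [B, T_en] -> [B, C, T_en]."""
    ends = torch.cumsum(durs, dim=1).long()
    starts = ends - durs.long()
    # cumulative sum with a leading zero so each token's sum is a difference of two gathers
    csum = torch.cumsum(torch.nn.functional.pad(values, (1, 0)), dim=2)
    idx = lambda t: t.unsqueeze(1).expand(-1, values.size(1), -1)
    token_sums = torch.gather(csum, 2, idx(ends)) - torch.gather(csum, 2, idx(starts))
    return token_sums / durs.unsqueeze(1).clamp(min=1)

z_p = torch.randn(2, 192, 50)                                  # [B, C, T_dec]
durs = torch.tensor([[10.0, 20.0, 20.0], [5.0, 25.0, 20.0]])   # [B, T_en], sums to T_dec
print(avg_over_durations(z_p, durs).shape)                     # torch.Size([2, 192, 3])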
@@ -1568,7 +1577,12 @@

         z_decoder_loss = None
         if self.args.use_z_decoder:
-            x_expanded = torch.einsum("klmn, kjm -> kjn", [attn, x])
+            cond_x = x
+            if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
+                pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(cond_x, x_lengths, pitch, attn.sum(3), g)
+                cond_x = cond_x + gt_avg_pitch_emb
+
+            x_expanded = torch.einsum("klmn, kjm -> kjn", [attn, cond_x])
             # prepare the conditional emb
             g_dec = g
             if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
@@ -1646,7 +1660,7 @@
             {
                 "model_outputs": o,
                 "alignments": attn.squeeze(1),
-                "m_p_unexpanded": m_p,
+                "m_p_unexpanded": m_p if new_m_p is None else new_m_p,
                 "z_p_avg": z_p_avg,
                 "m_p": m_p_expanded,
                 "logs_p": logs_p_expanded,
@@ -1774,14 +1788,14 @@

         pred_avg_pitch_emb = None
         if self.args.use_pitch and self.args.use_pitch_on_enc_input:
-            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(x, x_lengths, g_pp=g_dp, pitch_transform=pitch_transform)
+            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(x, x_lengths, g_pp=g, pitch_transform=pitch_transform)

         x, m_p, logs_p, x_mask = self.text_encoder(
             x,
             x_lengths,
             lang_emb=lang_emb,
             emo_emb=eg if not self.args.use_noise_scale_predictor else None,
-            pros_emb=pros_emb if not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module else None,
+            pros_emb=pros_emb if not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else None,
             pitch_emb=pred_avg_pitch_emb if self.args.use_pitch and self.args.use_pitch_on_enc_input else None,
         )

@@ -1811,19 +1825,22 @@
         attn_mask = x_mask * y_mask.transpose(1, 2)  # [B, 1, T_enc] * [B, T_dec, 1]
         attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))

-        if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
-            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(m_p, x_lengths, g_pp=g_dp, pitch_transform=pitch_transform)
+        if self.args.use_pitch and not self.args.use_pitch_on_enc_input and not self.args.use_z_decoder and not self.args.use_encoder_conditional_module:
+            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(m_p, x_lengths, g_pp=g, pitch_transform=pitch_transform)
             m_p = m_p + pred_avg_pitch_emb

         if self.args.use_encoder_conditional_module:
             g_cond = None
             cond_module_input = x
+            if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
+                _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(cond_module_input, x_lengths, g_pp=g, pitch_transform=pitch_transform)
+                cond_module_input = cond_module_input + pred_avg_pitch_emb
+
             if self.args.use_prosody_encoder:
                 if g_cond is None:
                     g_cond = pros_emb
                 else:
                     g_cond = torch.cat([g_cond, pros_emb], dim=1)  # [b, h1+h2, 1]
-
             if g_cond is not None:
                 cond_module_input = torch.cat((cond_module_input, g_cond.expand(-1, -1, cond_module_input.size(2))), dim=1)
             m_p = self.encoder_conditional_module(cond_module_input, x_mask)
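# Note on the two expansion forms: training expands token-level features with einsum over
# the 4-D alignment [B, 1, T_en, T_dec], while inference uses matmul on the 3-D path from
# generate_path. Both compute the same duration expansion; a quick check with random
# tensors (shapes as in the diff):
import torch

B, C, T_en, T_dec = 2, 192, 13, 50
attn = torch.rand(B, 1, T_en, T_dec)  # training keeps an extra channel dim on attn
x = torch.randn(B, C, T_en)

expanded_train = torch.einsum("klmn, kjm -> kjn", [attn, x])
expanded_infer = torch.matmul(attn.squeeze(1).transpose(1, 2), x.transpose(1, 2)).transpose(1, 2)
print(torch.allclose(expanded_train, expanded_infer, atol=1e-5))  # True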
@@ -1850,14 +1867,20 @@
         z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale

         if self.args.use_z_decoder:
-            x_expanded = torch.matmul(attn.transpose(1, 2), x.transpose(1, 2)).transpose(1, 2)
+            cond_x = x
+            if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
+                _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(cond_x, x_lengths, g_pp=g, pitch_transform=pitch_transform)
+                cond_x = cond_x + pred_avg_pitch_emb
+
+            x_expanded = torch.matmul(attn.transpose(1, 2), cond_x.transpose(1, 2)).transpose(1, 2)
             # prepare the conditional emb
             g_dec = g
             if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
                 if g_dec is None:
                     g_dec = eg
                 else:
-                    g_dec = torch.cat([g_dec, eg], dim=1) # [b, h1+h2, 1]
+                    g_dec = torch.cat([g_dec, eg], dim=1)  # [b, h1+h2, 1]
+
             if self.args.use_prosody_encoder:
                 if g_dec is None:
                     g_dec = pros_emb
@@ -2653,4 +2676,4 @@ class VitsCharacters(BaseCharacters):
             blank=self._blank,
             is_unique=False,
             is_sorted=True,
-        )
+        )
\ No newline at end of file
diff --git a/tests/tts_tests/test_vits_speaker_emb_with_pitch_predictor.py b/tests/tts_tests/test_vits_speaker_emb_with_pitch_predictor.py
index b2248d72..8798b864 100644
--- a/tests/tts_tests/test_vits_speaker_emb_with_pitch_predictor.py
+++ b/tests/tts_tests/test_vits_speaker_emb_with_pitch_predictor.py
@@ -48,12 +48,15 @@ config.model_args.alignments_cache_path = "tests/data/ljspeech/mas_alignments/al
 # pitch predictor
 config.model_args.use_pitch = True
-config.model_args.use_pitch_on_enc_input = True
+config.model_args.use_pitch_on_enc_input = False
 config.model_args.pitch_embedding_dim = 2

-config.model_args.condition_dp_on_speaker = True
+config.model_args.condition_dp_on_speaker = False

-config.model_args.use_latent_discriminator = True
+config.model_args.use_encoder_conditional_module = True
+config.model_args.use_z_decoder = False
+
+config.model_args.use_latent_discriminator = False

 config.save_json(config_path)

 # train the model for one epoch
diff --git a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
index 922c5fbe..b01ac3c8 100644
--- a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
+++ b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
@@ -46,7 +46,8 @@ config.model_args.d_vector_dim = 128
 config.model_args.use_prosody_encoder = True
 config.model_args.prosody_embedding_dim = 64

-config.model_args.use_encoder_conditional_module = True
+config.model_args.use_z_decoder = True
+config.model_args.use_encoder_conditional_module = False

 # active classifier
 config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
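# Note on the test changes: the two recipes now exercise the new conditioning modules one
# at a time, toggling use_encoder_conditional_module and use_z_decoder oppositely (the
# forward paths above treat them as alternatives). A compact restatement of the
# combinations covered; the dicts are illustrative, not test code:
pitch_predictor_test = dict(
    use_pitch=True,
    use_pitch_on_enc_input=False,  # pitch is injected after MAS, not at the encoder input
    condition_dp_on_speaker=False,
    use_encoder_conditional_module=True,
    use_z_decoder=False,
    use_latent_discriminator=False,
)
prosody_encoder_test = dict(
    use_prosody_encoder=True,
    use_z_decoder=True,
    use_encoder_conditional_module=False,
)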