diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index eb2e9976..04c74fd2 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -675,7 +675,6 @@ class VitsArgs(Coqpit):
 
     use_noise_scale_predictor: bool = False
     use_latent_discriminator: bool = False
-    use_avg_feature_on_latent_discriminator: bool = False
 
     # Pitch predictor
     use_pitch_on_enc_input: bool = False
@@ -1493,6 +1492,11 @@ class Vits(BaseTTS):
             pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(m_p, x_lengths, pitch, attn.sum(3), g_dp)
             m_p = m_p + gt_avg_pitch_emb
 
+        z_p_avg = None
+        if self.args.use_latent_discriminator:
+            # average the z_p for the latent discriminator
+            z_p_avg = average_over_durations(z_p, attn.sum(3).squeeze())
+
         # expand prior
         m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
         logs_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
@@ -1553,6 +1557,8 @@ class Vits(BaseTTS):
             {
                 "model_outputs": o,
                 "alignments": attn.squeeze(1),
+                "m_p_unexpanded": m_p,
+                "z_p_avg": z_p_avg,
                 "m_p": m_p_expanded,
                 "logs_p": logs_p_expanded,
                 "z": z,
@@ -1885,8 +1891,8 @@ class Vits(BaseTTS):
             scores_disc_fake, _, scores_disc_real, _, scores_disc_mp, _, scores_disc_zp, _ = self.disc(
                 outputs["model_outputs"].detach(),
                 outputs["waveform_seg"],
-                outputs["m_p"].detach(),
-                outputs["z_p"].detach(),
+                outputs["m_p_unexpanded"].detach(),
+                outputs["z_p_avg"].detach(),
             )
 
             # compute loss
@@ -1933,8 +1939,8 @@ class Vits(BaseTTS):
             ) = self.disc(
                 self.model_outputs_cache["model_outputs"],
                 self.model_outputs_cache["waveform_seg"],
-                self.model_outputs_cache["m_p"],
-                self.model_outputs_cache["z_p"].detach(),
+                self.model_outputs_cache["m_p_unexpanded"],
+                self.model_outputs_cache["z_p_avg"].detach(),
             )
 
             # compute losses
diff --git a/tests/tts_tests/test_vits_speaker_emb_with_pitch_predictor.py b/tests/tts_tests/test_vits_speaker_emb_with_pitch_predictor.py
index f3fe2bd5..3cd011fe 100644
--- a/tests/tts_tests/test_vits_speaker_emb_with_pitch_predictor.py
+++ b/tests/tts_tests/test_vits_speaker_emb_with_pitch_predictor.py
@@ -53,6 +53,8 @@
 config.model_args.pitch_embedding_dim = 2
 config.model_args.condition_dp_on_speaker = True
 
+config.model_args.use_latent_discriminator = True
+
 config.save_json(config_path)
 # train the model for one epoch
 command_train = (
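For context, the key change is that the latent discriminator now receives the token-level tensors: the unexpanded prior mean m_p, and z_p_avg, which collapses the frame-level posterior latent z_p down to one vector per input token by averaging each token's frames according to the alignment durations attn.sum(3). The sketch below illustrates that averaging operation; it is a minimal standalone version, not the repo's actual average_over_durations helper, and the tensor shapes are assumptions for illustration.

import torch

def average_over_durations_sketch(values: torch.Tensor, durs: torch.Tensor) -> torch.Tensor:
    # values: [B, C, T_frames] frame-level features (e.g. z_p)
    # durs:   [B, T_tokens]   integer frames per input token (e.g. attn.sum(3).squeeze())
    # returns [B, C, T_tokens] with each token's frames averaged
    durs = durs.long()
    ends = torch.cumsum(durs, dim=1)   # exclusive end frame of each token's segment
    starts = ends - durs               # start frame of each token's segment
    # prepend a zero so cum[..., k] is the sum of the first k frames
    cum = torch.nn.functional.pad(torch.cumsum(values, dim=2), (1, 0))
    expand = lambda t: t.unsqueeze(1).expand(-1, values.size(1), -1)
    # per-token sums via cumulative-sum differences, then divide by duration
    seg_sums = cum.gather(2, expand(ends)) - cum.gather(2, expand(starts))
    return seg_sums / durs.clamp(min=1).unsqueeze(1)

# e.g. a latent of shape [2, 192, 100] with 25 tokens of 4 frames each -> [2, 192, 25]
z_p = torch.randn(2, 192, 100)
durs = torch.full((2, 25), 4)
z_p_avg = average_over_durations_sketch(z_p, durs)

Averaging to token resolution keeps z_p_avg time-aligned with the unexpanded m_p, so the discriminator compares the two distributions token by token instead of frame by frame.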