diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index e0010250..80e697c1 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -674,7 +674,6 @@ class VitsArgs(Coqpit):
     use_prosody_enc_spk_reversal_classifier: bool = False
     use_prosody_enc_emo_classifier: bool = False
-    use_noise_scale_predictor: bool = False
     use_latent_discriminator: bool = False
     use_encoder_conditional_module: bool = False
@@ -787,8 +786,8 @@ class Vits(BaseTTS):
             self.args.kernel_size_text_encoder,
             self.args.dropout_p_text_encoder,
             language_emb_dim=self.embedded_language_dim,
-            emotion_emb_dim=self.args.emotion_embedding_dim if not self.args.use_noise_scale_predictor else 0,
-            prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else 0,
+            emotion_emb_dim=self.args.emotion_embedding_dim,
+            prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else 0,
             pitch_dim=self.args.pitch_embedding_dim if self.args.use_pitch and self.args.use_pitch_on_enc_input else 0,
         )
 
@@ -824,10 +823,10 @@ class Vits(BaseTTS):
             self.args.use_emotion_embedding
             or self.args.use_external_emotions_embeddings
             or self.args.use_speaker_embedding_as_emotion
-        ) and not self.args.use_noise_scale_predictor:
+        ):
             dp_extra_inp_dim += self.args.emotion_embedding_dim
 
-        if self.args.use_prosody_encoder and not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder:
+        if self.args.use_prosody_encoder and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder:
             dp_extra_inp_dim += self.args.prosody_embedding_dim
 
         if self.args.use_pitch and self.args.use_pitch_on_enc_input:
@@ -943,31 +942,6 @@ class Vits(BaseTTS):
                 reversal=False,
             )
 
-        if self.args.use_noise_scale_predictor:
-            noise_scale_predictor_input_dim = self.args.hidden_channels
-            if (
-                self.args.use_emotion_embedding
-                or self.args.use_external_emotions_embeddings
-                or self.args.use_speaker_embedding_as_emotion
-            ):
-                noise_scale_predictor_input_dim += self.args.emotion_embedding_dim
-
-            if self.args.use_prosody_encoder:
-                noise_scale_predictor_input_dim += self.args.prosody_embedding_dim
-
-            self.noise_scale_predictor = RelativePositionTransformer(
-                in_channels=noise_scale_predictor_input_dim,
-                out_channels=self.args.hidden_channels,
-                hidden_channels=noise_scale_predictor_input_dim,
-                hidden_channels_ffn=self.args.hidden_channels_ffn_text_encoder,
-                num_heads=self.args.num_heads_text_encoder,
-                num_layers=4,
-                kernel_size=self.args.kernel_size_text_encoder,
-                dropout_p=self.args.dropout_p_text_encoder,
-                layer_norm_type="2",
-                rel_attn_window_size=4,
-            )
-
         if self.args.use_emotion_embedding_squeezer:
             self.emotion_embedding_squeezer = nn.Linear(
                 in_features=self.args.emotion_embedding_squeezer_input_dim, out_features=self.args.emotion_embedding_dim
             )
@@ -1513,8 +1487,8 @@ class Vits(BaseTTS):
             x,
             x_lengths,
             lang_emb=lang_emb,
-            emo_emb=eg if not self.args.use_noise_scale_predictor else None,
-            pros_emb=pros_emb if not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else None,
+            emo_emb=eg,
+            pros_emb=pros_emb if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else None,
             pitch_emb=gt_avg_pitch_emb if self.args.use_pitch and self.args.use_pitch_on_enc_input else None,
         )
 
@@ -1603,22 +1577,6 @@ class Vits(BaseTTS):
             z_decoder = self.z_decoder(x_expanded, y_mask, g=g_dec)
             z_decoder_loss = torch.nn.functional.l1_loss(z_decoder * y_mask, z)
 
-        if self.args.use_noise_scale_predictor:
-            nsp_input = torch.transpose(m_p_expanded, 1, -1)
-            if self.args.use_prosody_encoder and pros_emb is not None:
-                nsp_input = torch.cat(
-                    (nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1
-                )
-
-            if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and eg is not None:
-                nsp_input = torch.cat(
-                    (nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1
-                )
-
-            nsp_input = torch.transpose(nsp_input, 1, -1) * y_mask
-            m_p_noise_scale = self.noise_scale_predictor(nsp_input, y_mask)
-            m_p_expanded = m_p_expanded + m_p_noise_scale * torch.exp(logs_p_expanded)
-
         # select a random feature segment for the waveform decoder
         z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
 
@@ -1794,8 +1752,8 @@ class Vits(BaseTTS):
             x,
             x_lengths,
             lang_emb=lang_emb,
-            emo_emb=eg if not self.args.use_noise_scale_predictor else None,
-            pros_emb=pros_emb if not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else None,
+            emo_emb=eg,
+            pros_emb=pros_emb if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else None,
             pitch_emb=pred_avg_pitch_emb if self.args.use_pitch and self.args.use_pitch_on_enc_input else None,
         )
 
@@ -1848,23 +1806,7 @@ class Vits(BaseTTS):
         m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
         logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
 
-        if self.args.use_noise_scale_predictor:
-            nsp_input = torch.transpose(m_p, 1, -1)
-            if self.args.use_prosody_encoder and pros_emb is not None:
-                nsp_input = torch.cat(
-                    (nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1
-                )
-
-            if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and eg is not None:
-                nsp_input = torch.cat(
-                    (nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1
-                )
-
-            nsp_input = torch.transpose(nsp_input, 1, -1) * y_mask
-            m_p_noise_scale = self.noise_scale_predictor(nsp_input, y_mask)
-            z_p = m_p + m_p_noise_scale * torch.exp(logs_p)
-        else:
-            z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
+        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
 
         if self.args.use_z_decoder:
             cond_x = x