diff --git a/TTS/tts/layers/vits/discriminator.py b/TTS/tts/layers/vits/discriminator.py
index e65a3eba..6efcc069 100644
--- a/TTS/tts/layers/vits/discriminator.py
+++ b/TTS/tts/layers/vits/discriminator.py
@@ -58,14 +58,14 @@ class VitsDiscriminator(nn.Module):
         use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm.
     """

-    def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False, use_latent_disc=False):
+    def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False, use_latent_disc=False, hidden_channels=None):
         super().__init__()
         self.nets = nn.ModuleList()
         self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm))
         self.nets.extend([DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods])
         self.disc_latent = None
         if use_latent_disc:
-            self.disc_latent = LatentDiscriminator(use_spectral_norm=use_spectral_norm)
+            self.disc_latent = LatentDiscriminator(use_spectral_norm=use_spectral_norm, hidden_channels=hidden_channels)

     def forward(self, x, x_hat=None, m_p=None, z_p=None):
         """
@@ -104,12 +104,13 @@ class VitsDiscriminator(nn.Module):
 class LatentDiscriminator(nn.Module):
     """Discriminator with the same architecture as the Univnet SpecDiscriminator"""

-    def __init__(self, use_spectral_norm=False):
+    def __init__(self, use_spectral_norm=False, hidden_channels=None):
         super().__init__()
         norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
+        self.hidden_channels = hidden_channels
         self.discriminators = nn.ModuleList(
             [
-                norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
+                norm_f(nn.Conv2d(1 if hidden_channels is None else hidden_channels, 32, kernel_size=(3, 9), padding=(1, 4))),
                 norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
                 norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
                 norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
@@ -121,6 +122,8 @@ class LatentDiscriminator(nn.Module):

     def forward(self, y):
         fmap = []
+        if self.hidden_channels is not None:
+            y = y.squeeze(1).unsqueeze(-1)
         for _, d in enumerate(self.discriminators):
             y = d(y)
             y = torch.nn.functional.leaky_relu(y, 0.1)
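
Shape note on the `LatentDiscriminator` change above: when `hidden_channels` is given, the latent's feature dimension becomes the Conv2d input-channel count instead of a fake single channel. A minimal standalone sketch in plain torch (not the library code; the `[B, 1, C, T]` input layout and the 192-dim latent are assumptions):

    import torch
    import torch.nn as nn

    hidden_channels = 192                            # assumed VITS latent size (args.hidden_channels)
    y = torch.randn(4, 1, hidden_channels, 100)      # assumed input layout [B, 1, C, T], as in the old 1-channel path

    y = y.squeeze(1).unsqueeze(-1)                   # -> [B, C, T, 1]; C now feeds Conv2d in_channels
    conv = nn.utils.weight_norm(nn.Conv2d(hidden_channels, 32, kernel_size=(3, 9), padding=(1, 4)))
    print(conv(y).shape)                             # torch.Size([4, 32, 100, 1])
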
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 995467e5..bc96d0c3 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -22,6 +22,7 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator, LatentDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
+from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
@@ -556,6 +557,7 @@ class VitsArgs(Coqpit):
     use_prosody_enc_spk_reversal_classifier: bool = False
     use_prosody_enc_emo_classifier: bool = False

+    use_noise_scale_predictor: bool = False
     use_prosody_conditional_flow_module: bool = False
     prosody_conditional_flow_module_on_decoder: bool = False
@@ -565,6 +567,7 @@ class VitsArgs(Coqpit):
     use_soft_dtw: bool = False

     use_latent_discriminator: bool = False
+    provide_hidden_dim_on_the_latent_discriminator: bool = False
     detach_dp_input: bool = True
     use_language_embedding: bool = False
@@ -650,8 +653,8 @@ class Vits(BaseTTS):
             self.args.kernel_size_text_encoder,
             self.args.dropout_p_text_encoder,
             language_emb_dim=self.embedded_language_dim,
-            emotion_emb_dim=self.args.emotion_embedding_dim if not self.args.use_prosody_conditional_flow_module else 0,
-            prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_prosody_conditional_flow_module else 0,
+            emotion_emb_dim=self.args.emotion_embedding_dim if not self.args.use_prosody_conditional_flow_module and not self.args.use_noise_scale_predictor else 0,
+            prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_prosody_conditional_flow_module and not self.args.use_noise_scale_predictor else 0,
         )

         self.posterior_encoder = PosteriorEncoder(
@@ -682,10 +685,10 @@ class Vits(BaseTTS):
             dp_cond_embedding_dim += self.args.prosody_embedding_dim

         dp_extra_inp_dim = 0
-        if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.use_prosody_conditional_flow_module:
+        if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.use_prosody_conditional_flow_module and not self.args.use_noise_scale_predictor:
             dp_extra_inp_dim += self.args.emotion_embedding_dim

-        if self.args.use_prosody_encoder and not self.args.use_prosody_conditional_flow_module:
+        if self.args.use_prosody_encoder and not self.args.use_prosody_conditional_flow_module and not self.args.use_noise_scale_predictor:
             dp_extra_inp_dim += self.args.prosody_embedding_dim

         if self.args.use_sdp:
@@ -756,6 +759,27 @@ class Vits(BaseTTS):
                 cond_channels=cond_embedding_dim,
             )

+        if self.args.use_noise_scale_predictor:
+            noise_scale_predictor_input_dim = self.args.hidden_channels
+            if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
+                noise_scale_predictor_input_dim += self.args.emotion_embedding_dim
+
+            if self.args.use_prosody_encoder:
+                noise_scale_predictor_input_dim += self.args.prosody_embedding_dim
+
+            self.noise_scale_predictor = RelativePositionTransformer(
+                in_channels=noise_scale_predictor_input_dim,
+                out_channels=self.args.hidden_channels,
+                hidden_channels=noise_scale_predictor_input_dim,
+                hidden_channels_ffn=self.args.hidden_channels_ffn_text_encoder,
+                num_heads=self.args.num_heads_text_encoder,
+                num_layers=4,
+                kernel_size=self.args.kernel_size_text_encoder,
+                dropout_p=self.args.dropout_p_text_encoder,
+                layer_norm_type="2",
+                rel_attn_window_size=4,
+            )
+
         if self.args.use_text_enc_spk_reversal_classifier:
             self.speaker_text_enc_reversal_classifier = ReversalClassifier(
                 in_channels=self.args.hidden_channels
@@ -793,6 +817,7 @@ class Vits(BaseTTS):
             periods=self.args.periods_multi_period_discriminator,
             use_spectral_norm=self.args.use_spectral_norm_disriminator,
             use_latent_disc=self.args.use_latent_discriminator,
+            hidden_channels=self.args.hidden_channels if self.args.provide_hidden_dim_on_the_latent_discriminator else None,
         )

     def init_multispeaker(self, config: Coqpit):
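
For reference, the `noise_scale_predictor` built in the hunk above is a `RelativePositionTransformer` over the expanded prior concatenated with optional emotion/prosody embeddings. A rough standalone sketch with made-up sizes; the ffn width, head count, kernel size and dropout stand in for the corresponding `model_args` values, and it assumes this TTS branch is importable:

    import torch
    from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer

    hidden_channels = 192                 # assumed args.hidden_channels
    emotion_emb_dim = 64                  # assumed args.emotion_embedding_dim
    in_dim = hidden_channels + emotion_emb_dim

    predictor = RelativePositionTransformer(
        in_channels=in_dim,
        out_channels=hidden_channels,
        hidden_channels=in_dim,
        hidden_channels_ffn=768,          # assumed args.hidden_channels_ffn_text_encoder
        num_heads=2,                      # assumed args.num_heads_text_encoder
        num_layers=4,
        kernel_size=3,                    # assumed args.kernel_size_text_encoder
        dropout_p=0.1,                    # assumed args.dropout_p_text_encoder
        layer_norm_type="2",
        rel_attn_window_size=4,
    )

    x = torch.randn(2, in_dim, 50)        # expanded prior means + emotion embedding, [B, C, T]
    x_mask = torch.ones(2, 1, 50)         # [B, 1, T] frame mask
    scale = predictor(x, x_mask)          # -> [2, hidden_channels, 50] per-frame scale
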
@@ -927,7 +952,7 @@ class Vits(BaseTTS):
                 if value == before_dict[key]:
                     raise RuntimeError(" [!] The weights of Text Encoder was not reinit check it !")
         print(" > Text Encoder was reinit.")
-    
+
     def init_emotion(self, emotion_manager: EmotionManager):  # pylint: disable=attribute-defined-outside-init
         """Initialize emotion modules of a model.
         A model can be trained either with a emotion embedding layer
@@ -1203,8 +1228,8 @@ class Vits(BaseTTS):
             x,
             x_lengths,
             lang_emb=lang_emb,
-            emo_emb=eg if not self.args.use_prosody_conditional_flow_module else None,
-            pros_emb=pros_emb if not self.args.use_prosody_conditional_flow_module else None
+            emo_emb=eg if not self.args.use_prosody_conditional_flow_module and not self.args.use_noise_scale_predictor else None,
+            pros_emb=pros_emb if not self.args.use_prosody_conditional_flow_module and not self.args.use_noise_scale_predictor else None
         )

         # reversal speaker loss to force the encoder to be speaker identity free
@@ -1249,6 +1274,17 @@ class Vits(BaseTTS):
         # expand prior
         m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
         logs_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
+        if self.args.use_noise_scale_predictor:
+            nsp_input = torch.transpose(m_p_expanded, 1, -1)
+            if self.args.use_prosody_encoder and pros_emb is not None:
+                nsp_input = torch.cat((nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1)
+
+            if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and eg is not None:
+                nsp_input = torch.cat((nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1)
+
+            nsp_input = torch.transpose(nsp_input, 1, -1) * y_mask
+            m_p_noise_scale = self.noise_scale_predictor(nsp_input, y_mask)
+            m_p_expanded = m_p_expanded + m_p_noise_scale * torch.exp(logs_p_expanded)

         # select a random feature segment for the waveform decoder
         z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
@@ -1424,8 +1460,8 @@ class Vits(BaseTTS):
             x,
             x_lengths,
             lang_emb=lang_emb,
-            emo_emb=eg if not self.args.use_prosody_conditional_flow_module else None,
-            pros_emb=pros_emb if not self.args.use_prosody_conditional_flow_module else None
+            emo_emb=eg if not self.args.use_prosody_conditional_flow_module and not self.args.use_noise_scale_predictor else None,
+            pros_emb=pros_emb if not self.args.use_prosody_conditional_flow_module and not self.args.use_noise_scale_predictor else None
         )

         # conditional module
@@ -1469,8 +1505,20 @@ class Vits(BaseTTS):
         m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
         logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
+
+        if self.args.use_noise_scale_predictor:
+            nsp_input = torch.transpose(m_p, 1, -1)
+            if self.args.use_prosody_encoder and pros_emb is not None:
+                nsp_input = torch.cat((nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1)

-        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
+            if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and eg is not None:
+                nsp_input = torch.cat((nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1)
+
+            nsp_input = torch.transpose(nsp_input, 1, -1) * y_mask
+            m_p_noise_scale = self.noise_scale_predictor(nsp_input, y_mask)
+            z_p = m_p + m_p_noise_scale * torch.exp(logs_p)
+        else:
+            z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale

         # conditional module
         if self.args.use_prosody_conditional_flow_module:
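
Net effect of the inference hunk above: with `use_noise_scale_predictor` enabled, the random prior sample is replaced by a learned, deterministic per-frame scale. A toy illustration of the two paths in plain torch (`pred` is a placeholder for the predictor output):

    import torch

    m_p = torch.zeros(1, 192, 50)         # expanded prior means, [B, C, T]
    logs_p = torch.zeros(1, 192, 50)      # expanded prior log-scales
    noise_scale = 0.667                   # stands in for self.inference_noise_scale

    # baseline VITS path: sample around the prior
    z_p_sampled = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

    # use_noise_scale_predictor path: deterministic, learned offset
    pred = torch.zeros_like(m_p)          # placeholder for noise_scale_predictor(nsp_input, y_mask)
    z_p_predicted = m_p + pred * torch.exp(logs_p)
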
diff --git a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
index aa47a2aa..4a160fa5 100644
--- a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
+++ b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
@@ -49,10 +49,13 @@
 config.model_args.use_prosody_enc_emo_classifier = False
 config.model_args.use_text_enc_emo_classifier = False
 config.model_args.use_prosody_encoder_z_p_input = True
-config.model_args.prosody_encoder_type = "vae"
+config.model_args.prosody_encoder_type = "gst"
 config.model_args.detach_prosody_enc_input = True

-config.model_args.use_latent_discriminator = False
+config.model_args.use_latent_discriminator = True
+config.model_args.provide_hidden_dim_on_the_latent_discriminator = True
+config.model_args.use_noise_scale_predictor = False
+
 # enable end2end loss
 config.model_args.use_end2end_loss = False
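
Usage note: outside this test, the same switches would be set on a training config. A hedged sketch assuming the patched branch is installed; the flag names come from this diff, the values are illustrative:

    from TTS.tts.configs.vits_config import VitsConfig

    config = VitsConfig()
    config.model_args.use_latent_discriminator = True                        # enable the latent discriminator
    config.model_args.provide_hidden_dim_on_the_latent_discriminator = True  # feed it hidden_channels-wide latents
    config.model_args.use_noise_scale_predictor = True                       # optional: learned prior noise scale
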