From 135363a6d127a231e1af6aa2e8e958d0e625f663 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Wed, 22 Jun 2022 17:28:25 -0300
Subject: [PATCH] Remove Pitch conditioning on encoder

---
 TTS/tts/layers/vits/networks.py               |  11 +-
 TTS/tts/layers/vits/prosody_encoder.py        |   2 +-
 TTS/tts/models/vits.py                        | 270 ++++++------------
 ...er_emb_with_pitch_predictor+prosody_enc.py |  98 +++++++
 ...t_vits_speaker_emb_with_pitch_predictor.py |   5 +-
 5 files changed, 195 insertions(+), 191 deletions(-)
 create mode 100644 tests/tts_tests/test_vits_speaker_emb_with_pitch_predictor+prosody_enc.py

diff --git a/TTS/tts/layers/vits/networks.py b/TTS/tts/layers/vits/networks.py
index dbcb7313..c7f952e1 100644
--- a/TTS/tts/layers/vits/networks.py
+++ b/TTS/tts/layers/vits/networks.py
@@ -39,8 +39,7 @@ class TextEncoder(nn.Module):
         dropout_p: float,
         language_emb_dim: int = None,
         emotion_emb_dim: int = None,
-        prosody_emb_dim: int = None,
-        pitch_dim: int = None,
+        prosody_emb_dim: int = None
     ):
         """Text Encoder for VITS model.
@@ -71,9 +70,6 @@ class TextEncoder(nn.Module):
         if prosody_emb_dim:
             hidden_channels += prosody_emb_dim

-        if pitch_dim:
-            hidden_channels += pitch_dim
-
         self.encoder = RelativePositionTransformer(
             in_channels=hidden_channels,
             out_channels=hidden_channels,
@@ -89,7 +85,7 @@ class TextEncoder(nn.Module):

         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

-    def forward(self, x, x_lengths, lang_emb=None, emo_emb=None, pros_emb=None, pitch_emb=None):
+    def forward(self, x, x_lengths, lang_emb=None, emo_emb=None, pros_emb=None):
         """
         Shapes:
             - x: :math:`[B, T]`
@@ -109,9 +105,6 @@ class TextEncoder(nn.Module):
         if pros_emb is not None:
             x = torch.cat((x, pros_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)

-        if pitch_emb is not None:
-            x = torch.cat((x, pitch_emb.transpose(2, 1)), dim=-1)
-
         x = torch.transpose(x, 1, -1)  # [b, h, t]
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)  # [b, 1, t]
diff --git a/TTS/tts/layers/vits/prosody_encoder.py b/TTS/tts/layers/vits/prosody_encoder.py
index bd7f3f33..cbc46d9f 100644
--- a/TTS/tts/layers/vits/prosody_encoder.py
+++ b/TTS/tts/layers/vits/prosody_encoder.py
@@ -30,5 +30,5 @@ class ResNetProsodyEncoder(ResNetSpeakerEncoder):
         super().__init__(*args, **kwargs)

     def forward(self, inputs, input_lengths=None, speaker_embedding=None):
-        style_embed = super().forward(inputs, l2_norm=True).unsqueeze(1)
+        style_embed = super().forward(inputs, l2_norm=False).unsqueeze(1)
         return style_embed, None
\ No newline at end of file
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 80e697c1..3d996d28 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -662,7 +662,6 @@ class VitsArgs(Coqpit):

     # prosody encoder
     use_prosody_encoder: bool = False
-    use_pros_enc_input_as_pros_emb: bool = False
     prosody_encoder_type: str = "gst"
     detach_prosody_enc_input: bool = False
     condition_pros_enc_on_speaker: bool = False
@@ -683,7 +682,6 @@ class VitsArgs(Coqpit):
     )

     # Pitch predictor
-    use_pitch_on_enc_input: bool = False
     use_pitch: bool = False
     pitch_predictor_hidden_channels: int = 256
     pitch_predictor_kernel_size: int = 3
@@ -692,7 +690,6 @@ class VitsArgs(Coqpit):
     detach_pp_input: bool = False
     use_precomputed_alignments: bool = False
     alignments_cache_path: str = ""
-    pitch_embedding_dim: int = 0
     pitch_mean: float = 0.0
     pitch_std: float = 0.0
@@ -787,8 +784,7 @@ class Vits(BaseTTS):
             self.args.dropout_p_text_encoder,
             language_emb_dim=self.embedded_language_dim,
             emotion_emb_dim=self.args.emotion_embedding_dim,
-            prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else 0,
-            pitch_dim=self.args.pitch_embedding_dim if self.args.use_pitch and self.args.use_pitch_on_enc_input else 0,
+            prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else 0
         )

         self.posterior_encoder = PosteriorEncoder(
@@ -815,7 +811,7 @@ class Vits(BaseTTS):
         if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
             dp_cond_embedding_dim += self.args.emotion_embedding_dim

-        if self.args.use_prosody_encoder:
+        if self.args.use_prosody_encoder and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder:
             dp_cond_embedding_dim += self.args.prosody_embedding_dim

         dp_extra_inp_dim = 0
@@ -829,9 +825,6 @@ class Vits(BaseTTS):
         if self.args.use_prosody_encoder and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder:
             dp_extra_inp_dim += self.args.prosody_embedding_dim

-        if self.args.use_pitch and self.args.use_pitch_on_enc_input:
-            dp_extra_inp_dim += self.args.pitch_embedding_dim
-
         if self.args.use_sdp:
             self.duration_predictor = StochasticDurationPredictor(
                 self.args.hidden_channels + dp_extra_inp_dim,
@@ -853,40 +846,30 @@ class Vits(BaseTTS):
             )

         if self.args.use_z_decoder:
-            dec_extra_inp_dim = self.cond_embedding_dim
-            if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
-                dec_extra_inp_dim += self.args.emotion_embedding_dim
-
-            if self.args.use_prosody_encoder:
-                dec_extra_inp_dim += self.args.prosody_embedding_dim
-
             self.z_decoder = forwardDecoder(
                 self.args.hidden_channels,
-                self.args.hidden_channels + dec_extra_inp_dim,
+                self.args.hidden_channels + self.cond_embedding_dim,
                 self.args.z_decoder_type,
                 self.args.z_decoder_params,
             )

         if self.args.use_encoder_conditional_module:
-            extra_inp_dim = 0
-            if self.args.use_prosody_encoder:
-                extra_inp_dim += self.args.prosody_embedding_dim
-
             self.encoder_conditional_module = forwardDecoder(
                 self.args.hidden_channels,
-                self.args.hidden_channels + extra_inp_dim,
+                self.args.hidden_channels,
                 self.args.conditional_module_type,
                 self.args.conditional_module_params,
             )

         if self.args.use_pitch:
-            if self.args.use_pitch_on_enc_input:
-                self.pitch_predictor_vocab_emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels)
+            if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder:
+                raise RuntimeError(
+                    " [!] use_pitch=True has no effect when both use_encoder_conditional_module and use_z_decoder are False. Please enable one of these conditional modules !!"
+                )

             self.pitch_emb = nn.Conv1d(
                 1,
-                self.args.hidden_channels if not self.args.use_pitch_on_enc_input else self.args.pitch_embedding_dim,
+                self.args.hidden_channels,
                 kernel_size=self.args.pitch_predictor_kernel_size,
                 padding=int((self.args.pitch_predictor_kernel_size - 1) / 2),
             )
@@ -900,34 +883,37 @@ class Vits(BaseTTS):
             )

         if self.args.use_prosody_encoder:
-            if self.args.use_pros_enc_input_as_pros_emb:
-                self.prosody_embedding_squeezer = nn.Linear(
-                    in_features=self.args.hidden_channels, out_features=self.args.prosody_embedding_dim
+            prosody_embedding_dim = self.args.prosody_embedding_dim if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else self.args.hidden_channels
+            if self.args.prosody_encoder_type == "gst":
+                self.prosody_encoder = VitsGST(
+                    num_mel=self.args.hidden_channels,
+                    num_heads=self.args.prosody_encoder_num_heads,
+                    num_style_tokens=self.args.prosody_encoder_num_tokens,
+                    gst_embedding_dim=prosody_embedding_dim,
+                    embedded_speaker_dim=self.cond_embedding_dim if self.args.condition_pros_enc_on_speaker else None,
                 )
+            elif self.args.prosody_encoder_type == "vae":
+                self.prosody_encoder = VitsVAE(
+                    num_mel=self.args.hidden_channels,
+                    capacitron_VAE_embedding_dim=prosody_embedding_dim,
+                    speaker_embedding_dim=self.cond_embedding_dim if self.args.condition_pros_enc_on_speaker else None,
+                )
+            elif self.args.prosody_encoder_type == "resnet":
+                self.prosody_encoder = ResNetProsodyEncoder(
+                    input_dim=self.args.hidden_channels,
+                    proj_dim=prosody_embedding_dim,
+                    layers=[1, 2, 2, 1],
+                    num_filters=[8, 16, 32, 64],
+                    encoder_type="ASP",
+                )
+            else:
-                if self.args.prosody_encoder_type == "gst":
-                    self.prosody_encoder = VitsGST(
-                        num_mel=self.args.hidden_channels,
-                        num_heads=self.args.prosody_encoder_num_heads,
-                        num_style_tokens=self.args.prosody_encoder_num_tokens,
-                        gst_embedding_dim=self.args.prosody_embedding_dim,
-                        embedded_speaker_dim=self.cond_embedding_dim if self.args.condition_pros_enc_on_speaker else None,
-                    )
-                elif self.args.prosody_encoder_type == "vae":
-                    self.prosody_encoder = VitsVAE(
-                        num_mel=self.args.hidden_channels,
-                        capacitron_VAE_embedding_dim=self.args.prosody_embedding_dim,
-                        speaker_embedding_dim=self.cond_embedding_dim if self.args.condition_pros_enc_on_speaker else None,
-                    )
-                elif self.args.prosody_encoder_type == "resnet":
-                    self.prosody_encoder = ResNetProsodyEncoder(
-                        input_dim=self.args.hidden_channels,
-                        proj_dim=self.args.prosody_embedding_dim,
-                    )
-                else:
-                    raise RuntimeError(
-                        f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!"
-                    )
+                raise RuntimeError(
+                    f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!"
+                )
+
+            print(f" > Using the prosody encoder type {self.args.prosody_encoder_type} with {sum(p.numel() for p in self.prosody_encoder.parameters() if p.requires_grad)} trainable parameters!")
+
         if self.args.use_prosody_enc_spk_reversal_classifier:
             self.speaker_reversal_classifier = ReversalClassifier(
                 in_channels=self.args.prosody_embedding_dim,
@@ -1274,9 +1260,6 @@ class Vits(BaseTTS):
             - pitch: :math:`(B, 1, T_{de})`
             - dr: :math:`(B, T_{en})`
         """
-        if self.args.use_pitch_on_enc_input:
-            o_en = self.pitch_predictor_vocab_emb(o_en)
-
         o_en = torch.transpose(o_en, 1, -1)  # [b, h, t]
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, o_en.size(2)), 1).to(o_en.dtype)  # [b, 1, t]
@@ -1445,13 +1428,6 @@ class Vits(BaseTTS):
                 g_dp = eg
             else:
                 g_dp = torch.cat([g_dp, eg], dim=1)  # [b, h1+h2, 1]
-
-        pitch_loss = None
-        gt_avg_pitch_emb = None
-        if self.args.use_pitch and self.args.use_pitch_on_enc_input:
-            if alignments is None:
-                raise RuntimeError(" [!] For condition the pitch on the Text Encoder you need to provide external alignments !")
-            pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(x, x_lengths, pitch, alignments.sum(3), g)

         # posterior encoder
         z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
@@ -1466,15 +1442,11 @@ class Vits(BaseTTS):
         l_pros_emotion = None
         if self.args.use_prosody_encoder:
             prosody_encoder_input = z_p if self.args.use_prosody_encoder_z_p_input else z
-            if not self.args.use_pros_enc_input_as_pros_emb:
-                pros_emb, vae_outputs = self.prosody_encoder(
-                    prosody_encoder_input.detach() if self.args.detach_prosody_enc_input else prosody_encoder_input,
-                    y_lengths,
-                    speaker_embedding=g if self.args.condition_pros_enc_on_speaker else None
-                )
-            else:
-                pros_emb = prosody_encoder_input.mean(2).unsqueeze(1).detach()
-                pros_emb = F.normalize(self.prosody_embedding_squeezer(pros_emb.squeeze(1))).unsqueeze(1)
+            pros_emb, vae_outputs = self.prosody_encoder(
+                prosody_encoder_input.detach() if self.args.detach_prosody_enc_input else prosody_encoder_input,
+                y_lengths,
+                speaker_embedding=g if self.args.condition_pros_enc_on_speaker else None
+            )

             pros_emb = pros_emb.transpose(1, 2)
@@ -1488,8 +1460,7 @@ class Vits(BaseTTS):
             x_lengths,
             lang_emb=lang_emb,
             emo_emb=eg,
-            pros_emb=pros_emb if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else None,
-            pitch_emb=gt_avg_pitch_emb if self.args.use_pitch and self.args.use_pitch_on_enc_input else None,
+            pros_emb=pros_emb if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else None
         )

         # reversal speaker loss to force the encoder to be speaker identity free
@@ -1502,7 +1473,12 @@ class Vits(BaseTTS):
         if self.args.use_text_enc_emo_classifier:
             _, l_text_emotion = self.emo_text_enc_classifier(m_p.transpose(1, 2), eid, x_mask=x_mask)

-        if self.args.use_prosody_encoder:
+        # add the prosody embedding to the encoder outputs when a conditional module consumes it
+        if self.args.use_prosody_encoder and (self.args.use_encoder_conditional_module or self.args.use_z_decoder):
+            x = x + pros_emb.expand(-1, -1, x.size(2))
+
+        # otherwise condition the duration predictor on the prosody embedding
+        if self.args.use_prosody_encoder and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder:
             if g_dp is None:
                 g_dp = pros_emb
             else:
@@ -1510,30 +1486,24 @@ class Vits(BaseTTS):

         outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)

+        # predict pitch and add it to the encoder outputs
+        pitch_loss = None
+        gt_avg_pitch_emb = None
+        if self.args.use_pitch:
+            pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(x, x_lengths, pitch, attn.sum(3), g)
+            x = x + gt_avg_pitch_emb
+
         z_p_avg = None
         if self.args.use_latent_discriminator:
             # average the z_p for the latent discriminator
             z_p_avg = average_over_durations(z_p, attn.sum(3).squeeze())

+        # conditional module
         conditional_module_loss = None
         new_m_p = None
         if self.args.use_encoder_conditional_module:
-            g_cond = None
-            cond_module_input = x
-            if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
-                pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(cond_module_input, x_lengths, pitch, attn.sum(3), g)
-                cond_module_input = cond_module_input + gt_avg_pitch_emb
-
-            if self.args.use_prosody_encoder:
-                if g_cond is None:
-                    g_cond = pros_emb
-                else:
-                    g_cond = torch.cat([g_cond, pros_emb], dim=1)  # [b, h1+h2, 1]
-
-            if g_cond is not None:
-                cond_module_input = torch.cat((cond_module_input, g_cond.expand(-1, -1, cond_module_input.size(2))), dim=1)
-
-            new_m_p = self.encoder_conditional_module(cond_module_input, x_mask) * x_mask
+            new_m_p = self.encoder_conditional_module(x, x_mask) * x_mask
             if z_p_avg is None:
                 z_p_avg = average_over_durations(z_p, attn.sum(3).squeeze()).detach()
             else:
@@ -1541,40 +1511,20 @@ class Vits(BaseTTS):

             conditional_module_loss = torch.nn.functional.l1_loss(new_m_p, z_p_avg)

-        if self.args.use_pitch and not self.args.use_pitch_on_enc_input and not self.args.use_z_decoder and not self.args.use_encoder_conditional_module:
-            pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(m_p, x_lengths, pitch, attn.sum(3), g)
-            m_p = m_p + gt_avg_pitch_emb
-
         # expand prior
         m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
         logs_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])

+        # z decoder
         z_decoder_loss = None
         if self.args.use_z_decoder:
-            cond_x = x
-            if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
-                pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(cond_x, x_lengths, pitch, attn.sum(3), g)
-                cond_x = cond_x + gt_avg_pitch_emb
-
-            x_expanded = torch.einsum("klmn, kjm -> kjn", [attn, cond_x])
-            # prepare the conditional emb
-            g_dec = g
-            if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
-                if g_dec is None:
-                    g_dec = eg
-                else:
-                    g_dec = torch.cat([g_dec, eg], dim=1)  # [b, h1+h2, 1]
-            if self.args.use_prosody_encoder:
-                if g_dec is None:
-                    g_dec = pros_emb
-                else:
-                    g_dec = torch.cat([g_dec, pros_emb], dim=1)  # [b, h1+h2, 1]
-
-            if g_dec is not None:
-                x_expanded = torch.cat((x_expanded, g_dec.expand(-1, -1, x_expanded.size(2))), dim=1)
+            dec_input = torch.einsum("klmn, kjm -> kjn", [attn, x])
+            # add speaker emb
+            if g is not None:
+                dec_input = torch.cat((dec_input, g.expand(-1, -1, dec_input.size(2))), dim=1)

             # decoder pass
-            z_decoder = self.z_decoder(x_expanded, y_mask, g=g_dec)
+            z_decoder = self.z_decoder(dec_input, y_mask, g=None)
             z_decoder_loss = torch.nn.functional.l1_loss(z_decoder * y_mask, z)

         # select a random feature segment for the waveform decoder
@@ -1723,19 +1673,26 @@ class Vits(BaseTTS):
            # extract posterior encoder feature
            pf_lengths = torch.tensor([pf.size(-1)]).to(pf.device)
            z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=ssg)
-            if not self.args.use_pros_enc_input_as_pros_emb:
-                if not self.args.use_prosody_encoder_z_p_input:
-                    pros_emb, _ = self.prosody_encoder(z_pro, pf_lengths, speaker_embedding=ssg if self.args.condition_pros_enc_on_speaker else None)
-                else:
-                    z_p_inf = self.flow(z_pro, z_pro_y_mask, g=ssg)
-                    pros_emb, _ = self.prosody_encoder(z_p_inf, pf_lengths, speaker_embedding=ssg if self.args.condition_pros_enc_on_speaker else None)
+            if not self.args.use_prosody_encoder_z_p_input:
+                pros_emb, _ = self.prosody_encoder(z_pro, pf_lengths, speaker_embedding=ssg if self.args.condition_pros_enc_on_speaker else None)
             else:
-                prosody_encoder_input = self.flow(z_pro, z_pro_y_mask, g=ssg) if self.args.use_prosody_encoder_z_p_input else z_pro
-                pros_emb = prosody_encoder_input.mean(2).unsqueeze(1)
-                pros_emb = F.normalize(self.prosody_embedding_squeezer(pros_emb.squeeze(1))).unsqueeze(1)
+                z_p_inf = self.flow(z_pro, z_pro_y_mask, g=ssg)
+                pros_emb, _ = self.prosody_encoder(z_p_inf, pf_lengths, speaker_embedding=ssg if self.args.condition_pros_enc_on_speaker else None)

             pros_emb = pros_emb.transpose(1, 2)

+        x, m_p, logs_p, x_mask = self.text_encoder(
+            x,
+            x_lengths,
+            lang_emb=lang_emb,
+            emo_emb=eg,
+            pros_emb=pros_emb if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else None
+        )
+
+        # add the prosody embedding to the encoder outputs when a conditional module consumes it
+        if self.args.use_prosody_encoder and (self.args.use_encoder_conditional_module or self.args.use_z_decoder):
+            x = x + pros_emb.expand(-1, -1, x.size(2))
+
         # duration predictor
         g_dp = g if self.args.condition_dp_on_speaker else None
         if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
@@ -1744,20 +1701,7 @@ class Vits(BaseTTS):
             if g_dp is None:
                 g_dp = eg
             else:
                 g_dp = torch.cat([g_dp, eg], dim=1)  # [b, h1+h2, 1]

-        pred_avg_pitch_emb = None
-        if self.args.use_pitch and self.args.use_pitch_on_enc_input:
-            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(x, x_lengths, g_pp=g, pitch_transform=pitch_transform)
-
-        x, m_p, logs_p, x_mask = self.text_encoder(
-            x,
-            x_lengths,
-            lang_emb=lang_emb,
-            emo_emb=eg,
-            pros_emb=pros_emb if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else None,
-            pitch_emb=pred_avg_pitch_emb if self.args.use_pitch and self.args.use_pitch_on_enc_input else None,
-        )
-
-        if self.args.use_prosody_encoder:
+        if self.args.use_prosody_encoder and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder:
             if g_dp is None:
                 g_dp = pros_emb
             else:
@@ -1783,25 +1727,12 @@ class Vits(BaseTTS):
         attn_mask = x_mask * y_mask.transpose(1, 2)  # [B, 1, T_enc] * [B, T_dec, 1]
         attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))

-        if self.args.use_pitch and not self.args.use_pitch_on_enc_input and not self.args.use_z_decoder and not self.args.use_encoder_conditional_module:
-            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(m_p, x_lengths, g_pp=g, pitch_transform=pitch_transform)
-            m_p = m_p + pred_avg_pitch_emb
+        if self.args.use_pitch:
+            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(x, x_lengths, g_pp=g, pitch_transform=pitch_transform)
+            x = x + pred_avg_pitch_emb

         if self.args.use_encoder_conditional_module:
-            g_cond = None
-            cond_module_input = x
-            if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
-                _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(cond_module_input, x_lengths, g_pp=g, pitch_transform=pitch_transform)
-                cond_module_input = cond_module_input + pred_avg_pitch_emb
-
-            if self.args.use_prosody_encoder:
-                if g_cond is None:
-                    g_cond = pros_emb
-                else:
-                    g_cond = torch.cat([g_cond, pros_emb], dim=1)  # [b, h1+h2, 1]
-            if g_cond is not None:
-                cond_module_input = torch.cat((cond_module_input, g_cond.expand(-1, -1, cond_module_input.size(2))), dim=1)
-            m_p = self.encoder_conditional_module(cond_module_input, x_mask)
+            m_p = self.encoder_conditional_module(x, x_mask)
         m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
         logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
@@ -1809,31 +1740,14 @@ class Vits(BaseTTS):
         z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale

         if self.args.use_z_decoder:
-            cond_x = x
-            if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
-                _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(cond_x, x_lengths, g_pp=g, pitch_transform=pitch_transform)
-                cond_x = cond_x + pred_avg_pitch_emb
+            dec_input = torch.matmul(attn.transpose(1, 2), x.transpose(1, 2)).transpose(1, 2)

-            x_expanded = torch.matmul(attn.transpose(1, 2), cond_x.transpose(1, 2)).transpose(1, 2)
-            # prepare the conditional emb
-            g_dec = g
-            if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
-                if g_dec is None:
-                    g_dec = eg
-                else:
-                    g_dec = torch.cat([g_dec, eg], dim=1)  # [b, h1+h2, 1]
-
-            if self.args.use_prosody_encoder:
-                if g_dec is None:
-                    g_dec = pros_emb
-                else:
-                    g_dec = torch.cat([g_dec, pros_emb], dim=1)  # [b, h1+h2, 1]
-
-            if g_dec is not None:
-                x_expanded = torch.cat((x_expanded, g_dec.expand(-1, -1, x_expanded.size(2))), dim=1)
+            # add speaker emb
+            if g is not None:
+                dec_input = torch.cat((dec_input, g.expand(-1, -1, dec_input.size(2))), dim=1)

             # decoder pass
-            z = self.z_decoder(x_expanded, y_mask, g=g_dec)
+            z = self.z_decoder(dec_input, y_mask, g=None)
         else:
             z = self.flow(z_p, y_mask, g=g, reverse=True)
diff --git a/tests/tts_tests/test_vits_speaker_emb_with_pitch_predictor+prosody_enc.py b/tests/tts_tests/test_vits_speaker_emb_with_pitch_predictor+prosody_enc.py
new file mode 100644
index 00000000..24ddac6a
--- /dev/null
+++ b/tests/tts_tests/test_vits_speaker_emb_with_pitch_predictor+prosody_enc.py
@@ -0,0 +1,98 @@
+import glob
+import os
+import shutil
+
+from trainer import get_last_checkpoint
+
+from tests import get_device_id, get_tests_output_path, run_cli
+from TTS.tts.configs.vits_config import VitsConfig
+
+config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
+output_path = os.path.join(get_tests_output_path(), "train_outputs")
+
+
+config = VitsConfig(
+    batch_size=2,
+    eval_batch_size=2,
+    num_loader_workers=0,
+    num_eval_loader_workers=0,
+    text_cleaner="english_cleaners",
+    use_phonemes=True,
+    phoneme_language="en-us",
+    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
+    run_eval=True,
+    test_delay_epochs=-1,
+    epochs=1,
+    print_step=1,
+    print_eval=True,
+    compute_pitch=True,
+    f0_cache_path="tests/data/ljspeech/f0_cache/",
+    test_sentences=[
+        ["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None, "ljspeech-2"],
+    ],
+)
+# set audio config
+config.audio.do_trim_silence = True
+config.audio.trim_db = 60
+
+# enable multi-speaker mode with a speaker embedding layer
+config.model_args.use_speaker_embedding = True
+config.model_args.use_d_vector_file = False
+config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
+config.model_args.speaker_embedding_channels = 128
+config.model_args.d_vector_dim = 128
+
+
+config.model_args.use_precomputed_alignments = True
+config.model_args.alignments_cache_path = "tests/data/ljspeech/mas_alignments/alignments/"
+
+# pitch predictor
+config.model_args.use_pitch = True
+config.model_args.use_pitch_on_enc_input = False
+config.model_args.pitch_embedding_dim = 2
+config.model_args.condition_dp_on_speaker = False
+
+
+
+# prosody encoder
+config.model_args.use_prosody_encoder = True
+config.model_args.prosody_embedding_dim = 64
+config.model_args.prosody_encoder_type = "resnet"
+
+config.model_args.use_encoder_conditional_module = True
+config.model_args.use_z_decoder = False
+
+config.model_args.use_latent_discriminator = False
+
+config.save_json(config_path)
+# train the model for one epoch
+command_train = (
+    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
+    f"--coqpit.output_path {output_path} "
+    "--coqpit.datasets.0.name ljspeech_test "
+    "--coqpit.datasets.0.meta_file_train metadata.csv "
+    "--coqpit.datasets.0.meta_file_val metadata.csv "
+    "--coqpit.datasets.0.path tests/data/ljspeech "
+    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
+    "--coqpit.test_delay_epochs 0"
+)
+run_cli(command_train)
+
+# Find latest folder
+continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
+
+# Inference using TTS API
+continue_config_path = os.path.join(continue_path, "config.json")
+continue_restore_path, _ = get_last_checkpoint(continue_path)
+out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
+speaker_id = "ljspeech-1"
+continue_speakers_path = os.path.join(continue_path, "speakers.json")
+
+
+inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} "
+run_cli(inference_command)
+
+# restore the model and continue training for one more epoch
+command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
+run_cli(command_train)
+shutil.rmtree(continue_path)
diff --git a/tests/tts_tests/test_vits_speaker_emb_with_pitch_predictor.py b/tests/tts_tests/test_vits_speaker_emb_with_pitch_predictor.py
index 8798b864..b9f10226 100644
--- a/tests/tts_tests/test_vits_speaker_emb_with_pitch_predictor.py
+++ b/tests/tts_tests/test_vits_speaker_emb_with_pitch_predictor.py
@@ -52,9 +52,8 @@ config.model_args.use_pitch_on_enc_input = False
 config.model_args.pitch_embedding_dim = 2
 config.model_args.condition_dp_on_speaker = False

-
-config.model_args.use_encoder_conditional_module = True
-config.model_args.use_z_decoder = False
+config.model_args.use_encoder_conditional_module = False
+config.model_args.use_z_decoder = True

 config.model_args.use_latent_discriminator = False
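
Reviewer note (illustration only, not part of the patch): after this change the pitch pathway no longer conditions the text-encoder input. Instead, the predicted average pitch embedding is added to the text-encoder output x, and x is consumed by the encoder conditional module or the z-decoder. Below is a minimal, self-contained sketch of that flow. The module names, Conv1d stand-ins, and sizes are made up for illustration; they are not the real forwardDecoder or pitch-predictor classes from TTS/tts/models/vits.py.

import torch
import torch.nn as nn

# Made-up sizes for illustration only.
B, H, T = 2, 192, 13  # batch, hidden channels, number of text tokens

# Stand-ins for the modules this patch wires together (not the real classes).
pitch_predictor = nn.Conv1d(H, 1, kernel_size=3, padding=1)  # one average-F0 value per token
pitch_emb = nn.Conv1d(1, H, kernel_size=3, padding=1)        # projects F0 back to hidden size
conditional_module = nn.Conv1d(H, H, kernel_size=1)          # stand-in for forwardDecoder

x = torch.randn(B, H, T)  # text-encoder output, [b, h, t]

# Predict per-token average pitch from the encoder OUTPUT, embed it, and add it
# to x -- instead of concatenating a pitch embedding to the encoder INPUT,
# which is the conditioning path this patch removes.
avg_pitch = pitch_predictor(x)   # [b, 1, t]
x = x + pitch_emb(avg_pitch)     # [b, h, t]

# A conditional module then regresses the prior from the pitch-conditioned x.
new_m_p = conditional_module(x)  # [b, h, t]
print(new_m_p.shape)             # torch.Size([2, 192, 13])

In the real model the conditional-module output is trained with an L1 loss against the duration-averaged z_p, so pitch only influences synthesis when use_encoder_conditional_module or use_z_decoder is enabled, which is why the constructor now raises when use_pitch is set without either of them.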