mirror of https://github.com/coqui-ai/TTS.git
Remove Pitch conditioning on encoder
This commit is contained in:
parent 25e3221daf
commit 85ce5aee34
@@ -39,8 +39,7 @@ class TextEncoder(nn.Module):
         dropout_p: float,
         language_emb_dim: int = None,
         emotion_emb_dim: int = None,
-        prosody_emb_dim: int = None,
-        pitch_dim: int = None,
+        prosody_emb_dim: int = None
     ):
         """Text Encoder for VITS model.

@@ -71,9 +70,6 @@ class TextEncoder(nn.Module):
         if prosody_emb_dim:
             hidden_channels += prosody_emb_dim

-        if pitch_dim:
-            hidden_channels += pitch_dim
-
         self.encoder = RelativePositionTransformer(
             in_channels=hidden_channels,
             out_channels=hidden_channels,
@@ -89,7 +85,7 @@ class TextEncoder(nn.Module):

         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

-    def forward(self, x, x_lengths, lang_emb=None, emo_emb=None, pros_emb=None, pitch_emb=None):
+    def forward(self, x, x_lengths, lang_emb=None, emo_emb=None, pros_emb=None):
         """
         Shapes:
             - x: :math:`[B, T]`
@@ -109,9 +105,6 @@ class TextEncoder(nn.Module):
         if pros_emb is not None:
             x = torch.cat((x, pros_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)

-        if pitch_emb is not None:
-            x = torch.cat((x, pitch_emb.transpose(2, 1)), dim=-1)
-
         x = torch.transpose(x, 1, -1)  # [b, h, t]
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)  # [b, 1, t]

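For reference, a minimal self-contained sketch (illustrative shapes only, not part of the patch) of the channel concatenation the encoder still applies to the prosody embedding, matching the torch.cat call kept in the hunk above; pitch is no longer concatenated there:

import torch

# Toy sizes; hidden and prosody dimensions are illustrative only.
B, T, H, P = 2, 7, 192, 64
x = torch.randn(B, T, H)          # token embeddings [B, T, H]
pros_emb = torch.randn(B, P, 1)   # one prosody vector per utterance [B, P, 1]

# Broadcast the utterance-level vector over time and append it as extra channels.
x = torch.cat((x, pros_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)
print(x.shape)  # torch.Size([2, 7, 256]) == [B, T, H + P]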
@@ -30,5 +30,5 @@ class ResNetProsodyEncoder(ResNetSpeakerEncoder):
         super().__init__(*args, **kwargs)

     def forward(self, inputs, input_lengths=None, speaker_embedding=None):
-        style_embed = super().forward(inputs, l2_norm=True).unsqueeze(1)
+        style_embed = super().forward(inputs, l2_norm=False).unsqueeze(1)
         return style_embed, None
@@ -662,7 +662,6 @@ class VitsArgs(Coqpit):

     # prosody encoder
     use_prosody_encoder: bool = False
-    use_pros_enc_input_as_pros_emb: bool = False
     prosody_encoder_type: str = "gst"
     detach_prosody_enc_input: bool = False
     condition_pros_enc_on_speaker: bool = False
@@ -683,7 +682,6 @@ class VitsArgs(Coqpit):
     )

     # Pitch predictor
-    use_pitch_on_enc_input: bool = False
     use_pitch: bool = False
     pitch_predictor_hidden_channels: int = 256
     pitch_predictor_kernel_size: int = 3
@@ -692,7 +690,6 @@ class VitsArgs(Coqpit):
     detach_pp_input: bool = False
     use_precomputed_alignments: bool = False
     alignments_cache_path: str = ""
-    pitch_embedding_dim: int = 0
     pitch_mean: float = 0.0
     pitch_std: float = 0.0

@@ -787,8 +784,7 @@ class Vits(BaseTTS):
             self.args.dropout_p_text_encoder,
             language_emb_dim=self.embedded_language_dim,
             emotion_emb_dim=self.args.emotion_embedding_dim,
-            prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else 0,
-            pitch_dim=self.args.pitch_embedding_dim if self.args.use_pitch and self.args.use_pitch_on_enc_input else 0,
+            prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else 0
         )

         self.posterior_encoder = PosteriorEncoder(
@@ -815,7 +811,7 @@ class Vits(BaseTTS):
         if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
             dp_cond_embedding_dim += self.args.emotion_embedding_dim

-        if self.args.use_prosody_encoder:
+        if self.args.use_prosody_encoder and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder:
             dp_cond_embedding_dim += self.args.prosody_embedding_dim

         dp_extra_inp_dim = 0
@@ -829,9 +825,6 @@ class Vits(BaseTTS):
         if self.args.use_prosody_encoder and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder:
             dp_extra_inp_dim += self.args.prosody_embedding_dim

-        if self.args.use_pitch and self.args.use_pitch_on_enc_input:
-            dp_extra_inp_dim += self.args.pitch_embedding_dim
-
         if self.args.use_sdp:
             self.duration_predictor = StochasticDurationPredictor(
                 self.args.hidden_channels + dp_extra_inp_dim,
@@ -853,40 +846,30 @@ class Vits(BaseTTS):
             )

         if self.args.use_z_decoder:
-            dec_extra_inp_dim = self.cond_embedding_dim
-            if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
-                dec_extra_inp_dim += self.args.emotion_embedding_dim
-
-            if self.args.use_prosody_encoder:
-                dec_extra_inp_dim += self.args.prosody_embedding_dim
-
             self.z_decoder = forwardDecoder(
                 self.args.hidden_channels,
-                self.args.hidden_channels + dec_extra_inp_dim,
+                self.args.hidden_channels + self.cond_embedding_dim,
                 self.args.z_decoder_type,
                 self.args.z_decoder_params,
             )

         if self.args.use_encoder_conditional_module:
-            extra_inp_dim = 0
-            if self.args.use_prosody_encoder:
-                extra_inp_dim += self.args.prosody_embedding_dim
-
             self.encoder_conditional_module = forwardDecoder(
                 self.args.hidden_channels,
-                self.args.hidden_channels + extra_inp_dim,
+                self.args.hidden_channels,
                 self.args.conditional_module_type,
                 self.args.conditional_module_params,
             )

         if self.args.use_pitch:
-            if self.args.use_pitch_on_enc_input:
-                self.pitch_predictor_vocab_emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels)
+            if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder:
+                raise RuntimeError(
+                    " [!] use_pitch=True is useless when use_encoder_conditional_module and use_z_decoder are both False. Please activate one of these conditional modules!"
+                )

             self.pitch_emb = nn.Conv1d(
                 1,
-                self.args.hidden_channels if not self.args.use_pitch_on_enc_input else self.args.pitch_embedding_dim,
+                self.args.hidden_channels,
                 kernel_size=self.args.pitch_predictor_kernel_size,
                 padding=int((self.args.pitch_predictor_kernel_size - 1) / 2),
             )
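For reference, a minimal sketch (illustrative sizes, not part of the patch) of the pitch path kept by the hunk above: a one-channel Conv1d projects the token-averaged pitch straight to hidden_channels so it can be added to, rather than concatenated with, the text-encoder output:

import torch
from torch import nn

hidden_channels, kernel_size = 192, 3  # illustrative values only
pitch_emb = nn.Conv1d(1, hidden_channels, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)

avg_pitch = torch.randn(2, 1, 13)        # one averaged pitch value per input token [B, 1, T_en]
gt_avg_pitch_emb = pitch_emb(avg_pitch)  # [B, hidden_channels, T_en]
x = torch.randn(2, hidden_channels, 13)  # text-encoder output [B, C, T_en]
x = x + gt_avg_pitch_emb                 # addition keeps the channel count unchanged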
@@ -900,34 +883,37 @@ class Vits(BaseTTS):
             )

         if self.args.use_prosody_encoder:
-            if self.args.use_pros_enc_input_as_pros_emb:
-                self.prosody_embedding_squeezer = nn.Linear(
-                    in_features=self.args.hidden_channels, out_features=self.args.prosody_embedding_dim
-                )
-            else:
-                if self.args.prosody_encoder_type == "gst":
-                    self.prosody_encoder = VitsGST(
-                        num_mel=self.args.hidden_channels,
-                        num_heads=self.args.prosody_encoder_num_heads,
-                        num_style_tokens=self.args.prosody_encoder_num_tokens,
-                        gst_embedding_dim=self.args.prosody_embedding_dim,
-                        embedded_speaker_dim=self.cond_embedding_dim if self.args.condition_pros_enc_on_speaker else None,
-                    )
-                elif self.args.prosody_encoder_type == "vae":
-                    self.prosody_encoder = VitsVAE(
-                        num_mel=self.args.hidden_channels,
-                        capacitron_VAE_embedding_dim=self.args.prosody_embedding_dim,
-                        speaker_embedding_dim=self.cond_embedding_dim if self.args.condition_pros_enc_on_speaker else None,
-                    )
-                elif self.args.prosody_encoder_type == "resnet":
-                    self.prosody_encoder = ResNetProsodyEncoder(
-                        input_dim=self.args.hidden_channels,
-                        proj_dim=self.args.prosody_embedding_dim,
-                    )
-                else:
-                    raise RuntimeError(
-                        f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!"
-                    )
+            prosody_embedding_dim = self.args.prosody_embedding_dim if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else self.args.hidden_channels
+            if self.args.prosody_encoder_type == "gst":
+                self.prosody_encoder = VitsGST(
+                    num_mel=self.args.hidden_channels,
+                    num_heads=self.args.prosody_encoder_num_heads,
+                    num_style_tokens=self.args.prosody_encoder_num_tokens,
+                    gst_embedding_dim=prosody_embedding_dim,
+                    embedded_speaker_dim=self.cond_embedding_dim if self.args.condition_pros_enc_on_speaker else None,
+                )
+            elif self.args.prosody_encoder_type == "vae":
+                self.prosody_encoder = VitsVAE(
+                    num_mel=self.args.hidden_channels,
+                    capacitron_VAE_embedding_dim=prosody_embedding_dim,
+                    speaker_embedding_dim=self.cond_embedding_dim if self.args.condition_pros_enc_on_speaker else None,
+                )
+            elif self.args.prosody_encoder_type == "resnet":
+                self.prosody_encoder = ResNetProsodyEncoder(
+                    input_dim=self.args.hidden_channels,
+                    proj_dim=prosody_embedding_dim,
+                    layers=[1, 2, 2, 1],
+                    num_filters=[8, 16, 32, 64],
+                    encoder_type="ASP",
+                )
+            else:
+                raise RuntimeError(
+                    f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!"
+                )
+
+            print(f" > Using the prosody Encoder type {self.args.prosody_encoder_type} with {len(list(self.prosody_encoder.parameters()))} trainable parameters !")

         if self.args.use_prosody_enc_spk_reversal_classifier:
             self.speaker_reversal_classifier = ReversalClassifier(
                 in_channels=self.args.prosody_embedding_dim,
@@ -1274,9 +1260,6 @@ class Vits(BaseTTS):
             - pitch: :math:`(B, 1, T_{de})`
             - dr: :math:`(B, T_{en})`
         """
-        if self.args.use_pitch_on_enc_input:
-            o_en = self.pitch_predictor_vocab_emb(o_en)
-            o_en = torch.transpose(o_en, 1, -1)  # [b, h, t]

         x_mask = torch.unsqueeze(sequence_mask(x_lengths, o_en.size(2)), 1).to(o_en.dtype)  # [b, 1, t]

@@ -1445,13 +1428,6 @@ class Vits(BaseTTS):
                 g_dp = eg
             else:
                 g_dp = torch.cat([g_dp, eg], dim=1)  # [b, h1+h2, 1]

-        pitch_loss = None
-        gt_avg_pitch_emb = None
-        if self.args.use_pitch and self.args.use_pitch_on_enc_input:
-            if alignments is None:
-                raise RuntimeError(" [!] For condition the pitch on the Text Encoder you need to provide external alignments !")
-            pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(x, x_lengths, pitch, alignments.sum(3), g)
-
         # posterior encoder
         z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
@@ -1466,15 +1442,11 @@ class Vits(BaseTTS):
         l_pros_emotion = None
         if self.args.use_prosody_encoder:
             prosody_encoder_input = z_p if self.args.use_prosody_encoder_z_p_input else z
-            if not self.args.use_pros_enc_input_as_pros_emb:
-                pros_emb, vae_outputs = self.prosody_encoder(
-                    prosody_encoder_input.detach() if self.args.detach_prosody_enc_input else prosody_encoder_input,
-                    y_lengths,
-                    speaker_embedding=g if self.args.condition_pros_enc_on_speaker else None
-                )
-            else:
-                pros_emb = prosody_encoder_input.mean(2).unsqueeze(1).detach()
-                pros_emb = F.normalize(self.prosody_embedding_squeezer(pros_emb.squeeze(1))).unsqueeze(1)
+            pros_emb, vae_outputs = self.prosody_encoder(
+                prosody_encoder_input.detach() if self.args.detach_prosody_enc_input else prosody_encoder_input,
+                y_lengths,
+                speaker_embedding=g if self.args.condition_pros_enc_on_speaker else None
+            )

             pros_emb = pros_emb.transpose(1, 2)

@@ -1488,8 +1460,7 @@ class Vits(BaseTTS):
             x_lengths,
             lang_emb=lang_emb,
             emo_emb=eg,
-            pros_emb=pros_emb if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else None,
-            pitch_emb=gt_avg_pitch_emb if self.args.use_pitch and self.args.use_pitch_on_enc_input else None,
+            pros_emb=pros_emb if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else None
         )

         # reversal speaker loss to force the encoder to be speaker identity free
@@ -1502,7 +1473,12 @@ class Vits(BaseTTS):
         if self.args.use_text_enc_emo_classifier:
             _, l_text_emotion = self.emo_text_enc_classifier(m_p.transpose(1, 2), eid, x_mask=x_mask)

-        if self.args.use_prosody_encoder:
+        # add prosody embedding on x if needed
+        if self.args.use_prosody_encoder and (self.args.use_encoder_conditional_module or self.args.use_z_decoder):
+            x = x + pros_emb.expand(-1, -1, x.size(2))
+
+        # add prosody embedding when necessary
+        if self.args.use_prosody_encoder and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder:
             if g_dp is None:
                 g_dp = pros_emb
             else:
@@ -1510,30 +1486,23 @@ class Vits(BaseTTS):

         outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)

+        # add pitch
+        pitch_loss = None
+        gt_avg_pitch_emb = None
+        if self.args.use_pitch:
+            pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(x, x_lengths, pitch, attn.sum(3), g)
+            x = x + gt_avg_pitch_emb
+
         z_p_avg = None
         if self.args.use_latent_discriminator:
             # average the z_p for the latent discriminator
             z_p_avg = average_over_durations(z_p, attn.sum(3).squeeze())

         # conditional module
         conditional_module_loss = None
         new_m_p = None
         if self.args.use_encoder_conditional_module:
-            g_cond = None
-            cond_module_input = x
-            if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
-                pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(cond_module_input, x_lengths, pitch, attn.sum(3), g)
-                cond_module_input = cond_module_input + gt_avg_pitch_emb
-
-            if self.args.use_prosody_encoder:
-                if g_cond is None:
-                    g_cond = pros_emb
-                else:
-                    g_cond = torch.cat([g_cond, pros_emb], dim=1)  # [b, h1+h2, 1]
-
-            if g_cond is not None:
-                cond_module_input = torch.cat((cond_module_input, g_cond.expand(-1, -1, cond_module_input.size(2))), dim=1)
-
-            new_m_p = self.encoder_conditional_module(cond_module_input, x_mask) * x_mask
+            new_m_p = self.encoder_conditional_module(x, x_mask) * x_mask
             if z_p_avg is None:
                 z_p_avg = average_over_durations(z_p, attn.sum(3).squeeze()).detach()
             else:
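The forward_pitch_predictor call above consumes per-token durations (attn.sum(3)) and, as the gt_avg_pitch_emb name suggests, averages the frame-level pitch contour over each input token before embedding it. A self-contained toy version of that averaging step (hypothetical helper, written only for illustration):

import torch

def average_pitch_over_durations(pitch, durations):
    """pitch: [B, 1, T_frames]; durations: [B, T_tokens] ints summing to T_frames. Returns [B, 1, T_tokens]."""
    outs = []
    for b in range(pitch.size(0)):
        start, vals = 0, []
        for d in durations[b].tolist():
            d = int(d)
            vals.append(pitch[b, :, start:start + d].mean(dim=-1) if d > 0 else pitch.new_zeros(1))
            start += d
        outs.append(torch.stack(vals, dim=-1))
    return torch.stack(outs, dim=0)

pitch = torch.randn(1, 1, 10)
durations = torch.tensor([[2, 3, 1, 4]])
print(average_pitch_over_durations(pitch, durations).shape)  # torch.Size([1, 1, 4])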
@@ -1541,40 +1510,20 @@ class Vits(BaseTTS):

             conditional_module_loss = torch.nn.functional.l1_loss(new_m_p, z_p_avg)

-        if self.args.use_pitch and not self.args.use_pitch_on_enc_input and not self.args.use_z_decoder and not self.args.use_encoder_conditional_module:
-            pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(m_p, x_lengths, pitch, attn.sum(3), g)
-            m_p = m_p + gt_avg_pitch_emb
-
         # expand prior
         m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
         logs_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])

         # z decoder
         z_decoder_loss = None
         if self.args.use_z_decoder:
-            cond_x = x
-            if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
-                pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(cond_x, x_lengths, pitch, attn.sum(3), g)
-                cond_x = cond_x + gt_avg_pitch_emb
-
-            x_expanded = torch.einsum("klmn, kjm -> kjn", [attn, cond_x])
-            # prepare the conditional emb
-            g_dec = g
-            if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
-                if g_dec is None:
-                    g_dec = eg
-                else:
-                    g_dec = torch.cat([g_dec, eg], dim=1)  # [b, h1+h2, 1]
-            if self.args.use_prosody_encoder:
-                if g_dec is None:
-                    g_dec = pros_emb
-                else:
-                    g_dec = torch.cat([g_dec, pros_emb], dim=1)  # [b, h1+h2, 1]
-
-            if g_dec is not None:
-                x_expanded = torch.cat((x_expanded, g_dec.expand(-1, -1, x_expanded.size(2))), dim=1)
+            dec_input = torch.einsum("klmn, kjm -> kjn", [attn, x])
+            # add speaker emb
+            if g is not None:
+                dec_input = torch.cat((dec_input, g.expand(-1, -1, dec_input.size(2))), dim=1)

             # decoder pass
-            z_decoder = self.z_decoder(x_expanded, y_mask, g=g_dec)
+            z_decoder = self.z_decoder(dec_input, y_mask, g=None)
             z_decoder_loss = torch.nn.functional.l1_loss(z_decoder * y_mask, z)

         # select a random feature segment for the waveform decoder
@@ -1723,19 +1672,26 @@ class Vits(BaseTTS):
             # extract posterior encoder feature
             pf_lengths = torch.tensor([pf.size(-1)]).to(pf.device)
             z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=ssg)
-            if not self.args.use_pros_enc_input_as_pros_emb:
-                if not self.args.use_prosody_encoder_z_p_input:
-                    pros_emb, _ = self.prosody_encoder(z_pro, pf_lengths, speaker_embedding=ssg if self.args.condition_pros_enc_on_speaker else None)
-                else:
-                    z_p_inf = self.flow(z_pro, z_pro_y_mask, g=ssg)
-                    pros_emb, _ = self.prosody_encoder(z_p_inf, pf_lengths, speaker_embedding=ssg if self.args.condition_pros_enc_on_speaker else None)
+            if not self.args.use_prosody_encoder_z_p_input:
+                pros_emb, _ = self.prosody_encoder(z_pro, pf_lengths, speaker_embedding=ssg if self.args.condition_pros_enc_on_speaker else None)
             else:
-                prosody_encoder_input = self.flow(z_pro, z_pro_y_mask, g=ssg) if self.args.use_prosody_encoder_z_p_input else z_pro
-                pros_emb = prosody_encoder_input.mean(2).unsqueeze(1)
-                pros_emb = F.normalize(self.prosody_embedding_squeezer(pros_emb.squeeze(1))).unsqueeze(1)
+                z_p_inf = self.flow(z_pro, z_pro_y_mask, g=ssg)
+                pros_emb, _ = self.prosody_encoder(z_p_inf, pf_lengths, speaker_embedding=ssg if self.args.condition_pros_enc_on_speaker else None)

             pros_emb = pros_emb.transpose(1, 2)

+        x, m_p, logs_p, x_mask = self.text_encoder(
+            x,
+            x_lengths,
+            lang_emb=lang_emb,
+            emo_emb=eg,
+            pros_emb=pros_emb if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else None
+        )
+
+        # add prosody embedding on x if needed
+        if self.args.use_prosody_encoder and (self.args.use_encoder_conditional_module or self.args.use_z_decoder):
+            x = x + pros_emb.expand(-1, -1, x.size(2))
+
         # duration predictor
         g_dp = g if self.args.condition_dp_on_speaker else None
         if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
@@ -1744,20 +1700,7 @@ class Vits(BaseTTS):
             else:
                 g_dp = torch.cat([g_dp, eg], dim=1)  # [b, h1+h2, 1]

-        pred_avg_pitch_emb = None
-        if self.args.use_pitch and self.args.use_pitch_on_enc_input:
-            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(x, x_lengths, g_pp=g, pitch_transform=pitch_transform)
-
-        x, m_p, logs_p, x_mask = self.text_encoder(
-            x,
-            x_lengths,
-            lang_emb=lang_emb,
-            emo_emb=eg,
-            pros_emb=pros_emb if not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else None,
-            pitch_emb=pred_avg_pitch_emb if self.args.use_pitch and self.args.use_pitch_on_enc_input else None,
-        )
-
-        if self.args.use_prosody_encoder:
+        if self.args.use_prosody_encoder and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder:
             if g_dp is None:
                 g_dp = pros_emb
             else:
@@ -1783,25 +1726,13 @@ class Vits(BaseTTS):
         attn_mask = x_mask * y_mask.transpose(1, 2)  # [B, 1, T_enc] * [B, T_dec, 1]
         attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))

-        if self.args.use_pitch and not self.args.use_pitch_on_enc_input and not self.args.use_z_decoder and not self.args.use_encoder_conditional_module:
-            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(m_p, x_lengths, g_pp=g, pitch_transform=pitch_transform)
-            m_p = m_p + pred_avg_pitch_emb
+        pred_avg_pitch_emb = None
+        if self.args.use_pitch:
+            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(x, x_lengths, g_pp=g, pitch_transform=pitch_transform)
+            x = x + pred_avg_pitch_emb

         if self.args.use_encoder_conditional_module:
-            g_cond = None
-            cond_module_input = x
-            if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
-                _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(cond_module_input, x_lengths, g_pp=g, pitch_transform=pitch_transform)
-                cond_module_input = cond_module_input + pred_avg_pitch_emb
-
-            if self.args.use_prosody_encoder:
-                if g_cond is None:
-                    g_cond = pros_emb
-                else:
-                    g_cond = torch.cat([g_cond, pros_emb], dim=1)  # [b, h1+h2, 1]
-            if g_cond is not None:
-                cond_module_input = torch.cat((cond_module_input, g_cond.expand(-1, -1, cond_module_input.size(2))), dim=1)
-            m_p = self.encoder_conditional_module(cond_module_input, x_mask)
+            m_p = self.encoder_conditional_module(x, x_mask)

         m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
         logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
@@ -1809,31 +1740,14 @@ class Vits(BaseTTS):
         z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale

         if self.args.use_z_decoder:
-            cond_x = x
-            if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
-                _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(cond_x, x_lengths, g_pp=g, pitch_transform=pitch_transform)
-                cond_x = cond_x + pred_avg_pitch_emb
+            dec_input = torch.matmul(attn.transpose(1, 2), x.transpose(1, 2)).transpose(1, 2)

-            x_expanded = torch.matmul(attn.transpose(1, 2), cond_x.transpose(1, 2)).transpose(1, 2)
-            # prepare the conditional emb
-            g_dec = g
-            if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
-                if g_dec is None:
-                    g_dec = eg
-                else:
-                    g_dec = torch.cat([g_dec, eg], dim=1)  # [b, h1+h2, 1]
-
-            if self.args.use_prosody_encoder:
-                if g_dec is None:
-                    g_dec = pros_emb
-                else:
-                    g_dec = torch.cat([g_dec, pros_emb], dim=1)  # [b, h1+h2, 1]
-
-            if g_dec is not None:
-                x_expanded = torch.cat((x_expanded, g_dec.expand(-1, -1, x_expanded.size(2))), dim=1)
+            # add speaker emb
+            if g is not None:
+                dec_input = torch.cat((dec_input, g.expand(-1, -1, dec_input.size(2))), dim=1)

             # decoder pass
-            z = self.z_decoder(x_expanded, y_mask, g=g_dec)
+            z = self.z_decoder(dec_input, y_mask, g=None)
         else:
             z = self.flow(z_p, y_mask, g=g, reverse=True)
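The dec_input line in the hunk above expands token-level features to frame level by multiplying with the hard alignment. A small self-contained example of the same matmul pattern (toy alignment and illustrative sizes, not part of the patch):

import torch

B, C, T_en, T_de = 1, 4, 3, 6
x = torch.randn(B, C, T_en)        # token-level features
attn = torch.zeros(B, T_en, T_de)  # hard alignment: token -> frames
attn[0, 0, 0:2] = 1.0
attn[0, 1, 2:5] = 1.0
attn[0, 2, 5:6] = 1.0

dec_input = torch.matmul(attn.transpose(1, 2), x.transpose(1, 2)).transpose(1, 2)
print(dec_input.shape)  # torch.Size([1, 4, 6]): each frame copies its token's features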
@@ -0,0 +1,98 @@
+import glob
+import os
+import shutil
+
+from trainer import get_last_checkpoint
+
+from tests import get_device_id, get_tests_output_path, run_cli
+from TTS.tts.configs.vits_config import VitsConfig
+
+config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
+output_path = os.path.join(get_tests_output_path(), "train_outputs")
+
+
+config = VitsConfig(
+    batch_size=2,
+    eval_batch_size=2,
+    num_loader_workers=0,
+    num_eval_loader_workers=0,
+    text_cleaner="english_cleaners",
+    use_phonemes=True,
+    phoneme_language="en-us",
+    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
+    run_eval=True,
+    test_delay_epochs=-1,
+    epochs=1,
+    print_step=1,
+    print_eval=True,
+    compute_pitch=True,
+    f0_cache_path="tests/data/ljspeech/f0_cache/",
+    test_sentences=[
+        ["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None, "ljspeech-2"],
+    ],
+)
+# set audio config
+config.audio.do_trim_silence = True
+config.audio.trim_db = 60
+
+# active multispeaker d-vec mode
+config.model_args.use_speaker_embedding = True
+config.model_args.use_d_vector_file = False
+config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
+config.model_args.speaker_embedding_channels = 128
+config.model_args.d_vector_dim = 128
+
+
+config.model_args.use_precomputed_alignments = True
+config.model_args.alignments_cache_path = "tests/data/ljspeech/mas_alignments/alignments/"
+
+# pitch predictor
+config.model_args.use_pitch = True
+config.model_args.use_pitch_on_enc_input = False
+config.model_args.pitch_embedding_dim = 2
+config.model_args.condition_dp_on_speaker = False
+
+
+
+# prosody encoder
+config.model_args.use_prosody_encoder = True
+config.model_args.prosody_embedding_dim = 64
+config.model_args.prosody_encoder_type = "resnet"
+
+config.model_args.use_encoder_conditional_module = True
+config.model_args.use_z_decoder = False
+
+config.model_args.use_latent_discriminator = False
+
+config.save_json(config_path)
+# train the model for one epoch
+command_train = (
+    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
+    f"--coqpit.output_path {output_path} "
+    "--coqpit.datasets.0.name ljspeech_test "
+    "--coqpit.datasets.0.meta_file_train metadata.csv "
+    "--coqpit.datasets.0.meta_file_val metadata.csv "
+    "--coqpit.datasets.0.path tests/data/ljspeech "
+    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
+    "--coqpit.test_delay_epochs 0"
+)
+run_cli(command_train)
+
+# Find latest folder
+continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
+
+# Inference using TTS API
+continue_config_path = os.path.join(continue_path, "config.json")
+continue_restore_path, _ = get_last_checkpoint(continue_path)
+out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
+speaker_id = "ljspeech-1"
+continue_speakers_path = os.path.join(continue_path, "speakers.json")
+
+
+inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} "
+run_cli(inference_command)
+
+# restore the model and continue training for one more epoch
+command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
+run_cli(command_train)
+shutil.rmtree(continue_path)
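The new test above exercises the combination the constructor now enforces: use_pitch together with either the encoder conditional module or the z decoder. A minimal sketch of that flag combination (values mirror the test; everything else left at defaults):

from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig()
config.model_args.use_pitch = True
config.model_args.use_prosody_encoder = True
config.model_args.prosody_encoder_type = "resnet"
config.model_args.use_encoder_conditional_module = True  # or set use_z_decoder = True instead
config.model_args.use_z_decoder = False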
@@ -52,9 +52,8 @@ config.model_args.use_pitch_on_enc_input = False
 config.model_args.pitch_embedding_dim = 2
 config.model_args.condition_dp_on_speaker = False

-
-config.model_args.use_encoder_conditional_module = True
-config.model_args.use_z_decoder = False
+config.model_args.use_encoder_conditional_module = False
+config.model_args.use_z_decoder = True

 config.model_args.use_latent_discriminator = False
