mirror of https://github.com/coqui-ai/TTS.git
Add pitch and prosody encoder support for the conditional module

This commit is contained in:
parent ef27039190
commit ae1f443b35
@@ -118,8 +118,8 @@ class VitsConfig(BaseTTSConfig):
     speaker_classifier_loss_alpha: float = 2.0
     emotion_classifier_loss_alpha: float = 4.0
     prosody_encoder_kl_loss_alpha: float = 5.0
-    disc_latent_loss_alpha: float = 5.0
-    gen_latent_loss_alpha: float = 5.0
+    disc_latent_loss_alpha: float = 1.0
+    gen_latent_loss_alpha: float = 1.0
     feat_latent_loss_alpha: float = 108.0
     pitch_loss_alpha: float = 5.0
     z_decoder_loss_alpha: float = 45.0
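Note: the fields above are per-term weights for the composite VITS training loss; this hunk lowers the latent discriminator and generator weights from 5.0 to 1.0 and leaves the others untouched. A minimal sketch of how such alphas typically scale their terms before summation; the helper and numeric values are illustrative, not this repository's code:

import torch

def weighted_total(losses: dict, alphas: dict) -> torch.Tensor:
    # Scale each raw loss term by its alpha from VitsConfig, then sum.
    return sum(alphas[name] * value for name, value in losses.items())

total = weighted_total(
    {"disc_latent_loss": torch.tensor(0.8), "gen_latent_loss": torch.tensor(0.6)},
    {"disc_latent_loss": 1.0, "gen_latent_loss": 1.0},  # lowered from 5.0 in this hunk
)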
@@ -788,7 +788,7 @@ class Vits(BaseTTS):
             self.args.dropout_p_text_encoder,
             language_emb_dim=self.embedded_language_dim,
             emotion_emb_dim=self.args.emotion_embedding_dim if not self.args.use_noise_scale_predictor else 0,
-            prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module else 0,
+            prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else 0,
             pitch_dim=self.args.pitch_embedding_dim if self.args.use_pitch and self.args.use_pitch_on_enc_input else 0,
         )
 
@@ -827,7 +827,7 @@ class Vits(BaseTTS):
         ) and not self.args.use_noise_scale_predictor:
             dp_extra_inp_dim += self.args.emotion_embedding_dim
 
-        if self.args.use_prosody_encoder and not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module:
+        if self.args.use_prosody_encoder and not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder:
             dp_extra_inp_dim += self.args.prosody_embedding_dim
 
         if self.args.use_pitch and self.args.use_pitch_on_enc_input:
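Note: both hunks above apply the same gating pattern. The prosody embedding is only concatenated onto the text-encoder input (and counted into the duration predictor's extra input size) when no downstream consumer takes it instead; this commit adds the z-decoder to the list of such consumers, alongside the noise-scale predictor and the encoder conditional module. A self-contained sketch of the gate; the dataclass is an illustrative stand-in for VitsArgs:

from dataclasses import dataclass

@dataclass
class Args:
    # Field names follow the diff; defaults are illustrative.
    prosody_embedding_dim: int = 64
    use_noise_scale_predictor: bool = False
    use_encoder_conditional_module: bool = True
    use_z_decoder: bool = False

args = Args()
prosody_emb_dim = (
    args.prosody_embedding_dim
    if not args.use_noise_scale_predictor
    and not args.use_encoder_conditional_module
    and not args.use_z_decoder
    else 0
)
assert prosody_emb_dim == 0  # the conditional module consumes the prosody embedding instead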
@@ -896,7 +896,7 @@ class Vits(BaseTTS):
                 self.args.pitch_predictor_hidden_channels,
                 self.args.pitch_predictor_kernel_size,
                 self.args.pitch_predictor_dropout_p,
-                cond_channels=dp_cond_embedding_dim,
+                cond_channels=self.cond_embedding_dim,
                 language_emb_dim=self.embedded_language_dim,
             )
 
@@ -1477,7 +1477,7 @@ class Vits(BaseTTS):
         if self.args.use_pitch and self.args.use_pitch_on_enc_input:
             if alignments is None:
                 raise RuntimeError(" [!] For condition the pitch on the Text Encoder you need to provide external alignments !")
-            pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(x, x_lengths, pitch, alignments.sum(3), g_dp)
+            pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(x, x_lengths, pitch, alignments.sum(3), g)
 
         # posterior encoder
         z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
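Note: together with the cond_channels change above, this switches the training-time pitch predictor from the duration-predictor conditioning g_dp (sized dp_cond_embedding_dim) to the plain speaker conditioning g (sized self.cond_embedding_dim). Constructor and call site must agree on channel count; a hedged sketch with illustrative sizes:

import torch
import torch.nn as nn

cond_embedding_dim = 512  # channels of the speaker embedding g (illustrative)
cond_layer = nn.Conv1d(cond_embedding_dim, 196, kernel_size=1)  # like a pitch predictor's cond projection

g = torch.randn(4, cond_embedding_dim, 1)  # [B, C, 1] speaker conditioning
out = cond_layer(g)  # only valid because cond_channels matches g's channel count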
@@ -1514,7 +1514,7 @@ class Vits(BaseTTS):
             x_lengths,
             lang_emb=lang_emb,
             emo_emb=eg if not self.args.use_noise_scale_predictor else None,
-            pros_emb=pros_emb if not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module else None,
+            pros_emb=pros_emb if not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else None,
             pitch_emb=gt_avg_pitch_emb if self.args.use_pitch and self.args.use_pitch_on_enc_input else None,
         )
 
@@ -1536,10 +1536,20 @@ class Vits(BaseTTS):
 
         outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)
 
+        z_p_avg = None
+        if self.args.use_latent_discriminator:
+            # average the z_p for the latent discriminator
+            z_p_avg = average_over_durations(z_p, attn.sum(3).squeeze())
+
         conditional_module_loss = None
+        new_m_p = None
         if self.args.use_encoder_conditional_module:
             g_cond = None
             cond_module_input = x
+            if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
+                pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(cond_module_input, x_lengths, pitch, attn.sum(3), g)
+                cond_module_input = cond_module_input + gt_avg_pitch_emb
+
             if self.args.use_prosody_encoder:
                 if g_cond is None:
                     g_cond = pros_emb
@@ -1549,18 +1559,17 @@ class Vits(BaseTTS):
             if g_cond is not None:
                 cond_module_input = torch.cat((cond_module_input, g_cond.expand(-1, -1, cond_module_input.size(2))), dim=1)
 
-            new_m_p = self.encoder_conditional_module(cond_module_input, x_mask)
-            z_p_avg = average_over_durations(z_p, attn.sum(3).squeeze()).detach()
-            conditional_module_loss = torch.nn.functional.l1_loss(new_m_p * x_mask, z_p_avg)
-
-        if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
-            pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(m_p, x_lengths, pitch, attn.sum(3), g_dp)
-            m_p = m_p + gt_avg_pitch_emb
-
-        z_p_avg = None
-        if self.args.use_latent_discriminator:
-            # average the z_p for the latent discriminator
-            z_p_avg = average_over_durations(z_p, attn.sum(3).squeeze())
+            new_m_p = self.encoder_conditional_module(cond_module_input, x_mask) * x_mask
+            if z_p_avg is None:
+                z_p_avg = average_over_durations(z_p, attn.sum(3).squeeze()).detach()
+            else:
+                z_p_avg = z_p_avg.detach()
+
+            conditional_module_loss = torch.nn.functional.l1_loss(new_m_p, z_p_avg)
+
+        if self.args.use_pitch and not self.args.use_pitch_on_enc_input and not self.args.use_z_decoder and not self.args.use_encoder_conditional_module:
+            pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(m_p, x_lengths, pitch, attn.sum(3), g)
+            m_p = m_p + gt_avg_pitch_emb
 
         # expand prior
         m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
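Note: the two hunks above are the core of the commit. The conditional module's masked prediction new_m_p is regressed with an L1 loss onto the duration-averaged flow latent z_p_avg, which is now computed once (reused from the latent discriminator when that is enabled) and always detached so the regression target carries no gradient. A hedged sketch of the objective; the pooling helper is a naive reference for what average_over_durations does, not the repository's vectorized version:

import torch

def average_over_durations(values: torch.Tensor, durs: torch.Tensor) -> torch.Tensor:
    # values: [B, C, T_frames], durs: [B, T_tokens] -> token-level means [B, C, T_tokens]
    ends = torch.cumsum(durs, dim=1).long()
    starts = ends - durs.long()
    out = torch.zeros(values.size(0), values.size(1), durs.size(1))
    for b in range(values.size(0)):
        for t in range(durs.size(1)):
            if ends[b, t] > starts[b, t]:
                out[b, :, t] = values[b, :, starts[b, t] : ends[b, t]].mean(dim=-1)
    return out

z_p = torch.randn(2, 192, 100)                          # frame-level flow latents
durs = torch.full((2, 25), 4.0)                         # per-token durations, i.e. attn.sum(3).squeeze()
z_p_avg = average_over_durations(z_p, durs).detach()    # target: no gradient into the posterior path
new_m_p = torch.randn(2, 192, 25, requires_grad=True)   # stand-in for the conditional module output
conditional_module_loss = torch.nn.functional.l1_loss(new_m_p, z_p_avg)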
@@ -1568,7 +1577,12 @@ class Vits(BaseTTS):
 
         z_decoder_loss = None
         if self.args.use_z_decoder:
-            x_expanded = torch.einsum("klmn, kjm -> kjn", [attn, x])
+            cond_x = x
+            if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
+                pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(cond_x, x_lengths, pitch, attn.sum(3), g)
+                cond_x = cond_x + gt_avg_pitch_emb
+
+            x_expanded = torch.einsum("klmn, kjm -> kjn", [attn, cond_x])
             # prepare the conditional emb
             g_dec = g
             if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
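Note: with the z-decoder enabled, the token-level encoder output is pitch-conditioned first and only then expanded to frame rate; the einsum above is a batched matmul over the hard alignment. A small self-contained sketch of the expansion and its matmul equivalent (the form the inference path uses further down); shapes are illustrative:

import torch

B, C, T_tok, T_frm = 2, 192, 25, 100
attn = torch.zeros(B, 1, T_tok, T_frm)  # hard alignment path, [B, 1, T_tok, T_frm]
for tok in range(T_tok):
    attn[:, 0, tok, tok * 4 : (tok + 1) * 4] = 1.0  # each token spans 4 frames

cond_x = torch.randn(B, C, T_tok)  # pitch-conditioned encoder output

x_expanded = torch.einsum("klmn, kjm -> kjn", [attn, cond_x])  # [B, C, T_frm]
x_expanded_2 = torch.matmul(attn.squeeze(1).transpose(1, 2), cond_x.transpose(1, 2)).transpose(1, 2)
assert torch.allclose(x_expanded, x_expanded_2)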
@@ -1646,7 +1660,7 @@ class Vits(BaseTTS):
             {
                 "model_outputs": o,
                 "alignments": attn.squeeze(1),
-                "m_p_unexpanded": m_p,
+                "m_p_unexpanded": m_p if new_m_p is None else new_m_p,
                 "z_p_avg": z_p_avg,
                 "m_p": m_p_expanded,
                 "logs_p": logs_p_expanded,
@@ -1774,14 +1788,14 @@ class Vits(BaseTTS):
 
         pred_avg_pitch_emb = None
         if self.args.use_pitch and self.args.use_pitch_on_enc_input:
-            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(x, x_lengths, g_pp=g_dp, pitch_transform=pitch_transform)
+            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(x, x_lengths, g_pp=g, pitch_transform=pitch_transform)
 
         x, m_p, logs_p, x_mask = self.text_encoder(
             x,
             x_lengths,
             lang_emb=lang_emb,
             emo_emb=eg if not self.args.use_noise_scale_predictor else None,
-            pros_emb=pros_emb if not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module else None,
+            pros_emb=pros_emb if not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module and not self.args.use_z_decoder else None,
             pitch_emb=pred_avg_pitch_emb if self.args.use_pitch and self.args.use_pitch_on_enc_input else None,
         )
 
@@ -1811,19 +1825,22 @@ class Vits(BaseTTS):
         attn_mask = x_mask * y_mask.transpose(1, 2)  # [B, 1, T_enc] * [B, T_dec, 1]
         attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))
 
-        if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
-            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(m_p, x_lengths, g_pp=g_dp, pitch_transform=pitch_transform)
+        if self.args.use_pitch and not self.args.use_pitch_on_enc_input and not self.args.use_z_decoder and not self.args.use_encoder_conditional_module:
+            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(m_p, x_lengths, g_pp=g, pitch_transform=pitch_transform)
             m_p = m_p + pred_avg_pitch_emb
 
         if self.args.use_encoder_conditional_module:
             g_cond = None
             cond_module_input = x
+            if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
+                _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(cond_module_input, x_lengths, g_pp=g, pitch_transform=pitch_transform)
+                cond_module_input = cond_module_input + pred_avg_pitch_emb
+
             if self.args.use_prosody_encoder:
                 if g_cond is None:
                     g_cond = pros_emb
                 else:
                     g_cond = torch.cat([g_cond, pros_emb], dim=1)  # [b, h1+h2, 1]
 
             if g_cond is not None:
                 cond_module_input = torch.cat((cond_module_input, g_cond.expand(-1, -1, cond_module_input.size(2))), dim=1)
             m_p = self.encoder_conditional_module(cond_module_input, x_mask)
@@ -1850,14 +1867,20 @@ class Vits(BaseTTS):
         z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
 
         if self.args.use_z_decoder:
-            x_expanded = torch.matmul(attn.transpose(1, 2), x.transpose(1, 2)).transpose(1, 2)
+            cond_x = x
+            if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
+                _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(cond_x, x_lengths, g_pp=g, pitch_transform=pitch_transform)
+                cond_x = cond_x + pred_avg_pitch_emb
+
+            x_expanded = torch.matmul(attn.transpose(1, 2), cond_x.transpose(1, 2)).transpose(1, 2)
             # prepare the conditional emb
             g_dec = g
             if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
                 if g_dec is None:
                     g_dec = eg
                 else:
                     g_dec = torch.cat([g_dec, eg], dim=1)  # [b, h1+h2, 1]
 
             if self.args.use_prosody_encoder:
                 if g_dec is None:
                     g_dec = pros_emb
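Note: at inference the z-decoder path mirrors training: x is pitch-conditioned into cond_x, expanded through the alignment, and the decoder conditioning g_dec is assembled by stacking whichever of the speaker, emotion, and prosody embeddings are active on the channel axis. A hedged sketch of that assembly (shapes are illustrative):

import torch

g = torch.randn(2, 512, 1)        # speaker embedding
eg = torch.randn(2, 64, 1)        # emotion embedding
pros_emb = torch.randn(2, 64, 1)  # prosody encoder output

g_dec = g
for emb in (eg, pros_emb):
    g_dec = emb if g_dec is None else torch.cat([g_dec, emb], dim=1)  # [b, h1+h2, 1]
assert g_dec.shape == (2, 640, 1)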
@@ -2653,4 +2676,4 @@ class VitsCharacters(BaseCharacters):
             blank=self._blank,
             is_unique=False,
             is_sorted=True,
         )
@@ -48,12 +48,15 @@ config.model_args.alignments_cache_path = "tests/data/ljspeech/mas_alignments/al
 
 # pitch predictor
 config.model_args.use_pitch = True
-config.model_args.use_pitch_on_enc_input = True
+config.model_args.use_pitch_on_enc_input = False
 config.model_args.pitch_embedding_dim = 2
-config.model_args.condition_dp_on_speaker = True
+config.model_args.condition_dp_on_speaker = False
 
 
-config.model_args.use_latent_discriminator = True
+config.model_args.use_encoder_conditional_module = True
+config.model_args.use_z_decoder = False
+
+config.model_args.use_latent_discriminator = False
 
 config.save_json(config_path)
 # train the model for one epoch
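Note: this test now exercises the combination the commit targets: pitch conditioned on the duration side rather than the encoder input, a speaker-independent duration predictor, the encoder conditional module enabled, and both the z-decoder and the latent discriminator disabled (so the module's L1 target z_p_avg is computed locally and detached). Restated as one contiguous snippet; the import path follows upstream Coqui TTS and may differ in this fork:

from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig()
config.model_args.use_pitch = True
config.model_args.use_pitch_on_enc_input = False   # pitch conditions the duration side
config.model_args.pitch_embedding_dim = 2
config.model_args.condition_dp_on_speaker = False
config.model_args.use_encoder_conditional_module = True
config.model_args.use_z_decoder = False
config.model_args.use_latent_discriminator = False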
@@ -46,7 +46,8 @@ config.model_args.d_vector_dim = 128
 config.model_args.use_prosody_encoder = True
 config.model_args.prosody_embedding_dim = 64
 
-config.model_args.use_encoder_conditional_module = True
+config.model_args.use_z_decoder = True
+config.model_args.use_encoder_conditional_module = False
 
 # active classifier
 config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"