mirror of https://github.com/coqui-ai/TTS.git
Add latent discriminator over unexpanded m_p
This commit is contained in:
parent
569decba64
commit
3165c55fae
|
@ -675,7 +675,6 @@ class VitsArgs(Coqpit):
|
||||||
|
|
||||||
use_noise_scale_predictor: bool = False
|
use_noise_scale_predictor: bool = False
|
||||||
use_latent_discriminator: bool = False
|
use_latent_discriminator: bool = False
|
||||||
use_avg_feature_on_latent_discriminator: bool = False
|
|
||||||
|
|
||||||
# Pitch predictor
|
# Pitch predictor
|
||||||
use_pitch_on_enc_input: bool = False
|
use_pitch_on_enc_input: bool = False
|
||||||
|
@ -1493,6 +1492,11 @@ class Vits(BaseTTS):
|
||||||
pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(m_p, x_lengths, pitch, attn.sum(3), g_dp)
|
pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(m_p, x_lengths, pitch, attn.sum(3), g_dp)
|
||||||
m_p = m_p + gt_avg_pitch_emb
|
m_p = m_p + gt_avg_pitch_emb
|
||||||
|
|
||||||
|
z_p_avg = None
|
||||||
|
if self.args.use_latent_discriminator:
|
||||||
|
# average the z_p for the latent discriminator
|
||||||
|
z_p_avg = average_over_durations(z_p, attn.sum(3).squeeze())
|
||||||
|
|
||||||
# expand prior
|
# expand prior
|
||||||
m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
|
m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
|
||||||
logs_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
|
logs_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
|
||||||
|
@ -1553,6 +1557,8 @@ class Vits(BaseTTS):
|
||||||
{
|
{
|
||||||
"model_outputs": o,
|
"model_outputs": o,
|
||||||
"alignments": attn.squeeze(1),
|
"alignments": attn.squeeze(1),
|
||||||
|
"m_p_unexpanded": m_p,
|
||||||
|
"z_p_avg": z_p_avg,
|
||||||
"m_p": m_p_expanded,
|
"m_p": m_p_expanded,
|
||||||
"logs_p": logs_p_expanded,
|
"logs_p": logs_p_expanded,
|
||||||
"z": z,
|
"z": z,
|
||||||
|
@ -1885,8 +1891,8 @@ class Vits(BaseTTS):
|
||||||
scores_disc_fake, _, scores_disc_real, _, scores_disc_mp, _, scores_disc_zp, _ = self.disc(
|
scores_disc_fake, _, scores_disc_real, _, scores_disc_mp, _, scores_disc_zp, _ = self.disc(
|
||||||
outputs["model_outputs"].detach(),
|
outputs["model_outputs"].detach(),
|
||||||
outputs["waveform_seg"],
|
outputs["waveform_seg"],
|
||||||
outputs["m_p"].detach(),
|
outputs["m_p_unexpanded"].detach(),
|
||||||
outputs["z_p"].detach(),
|
outputs["z_p_avg"].detach(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# compute loss
|
# compute loss
|
||||||
|
@ -1933,8 +1939,8 @@ class Vits(BaseTTS):
|
||||||
) = self.disc(
|
) = self.disc(
|
||||||
self.model_outputs_cache["model_outputs"],
|
self.model_outputs_cache["model_outputs"],
|
||||||
self.model_outputs_cache["waveform_seg"],
|
self.model_outputs_cache["waveform_seg"],
|
||||||
self.model_outputs_cache["m_p"],
|
self.model_outputs_cache["m_p_unexpanded"],
|
||||||
self.model_outputs_cache["z_p"].detach(),
|
self.model_outputs_cache["z_p_avg"].detach(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# compute losses
|
# compute losses
|
||||||
|
|
|
@ -53,6 +53,8 @@ config.model_args.pitch_embedding_dim = 2
|
||||||
config.model_args.condition_dp_on_speaker = True
|
config.model_args.condition_dp_on_speaker = True
|
||||||
|
|
||||||
|
|
||||||
|
config.model_args.use_latent_discriminator = True
|
||||||
|
|
||||||
config.save_json(config_path)
|
config.save_json(config_path)
|
||||||
# train the model for one epoch
|
# train the model for one epoch
|
||||||
command_train = (
|
command_train = (
|
||||||
|
|
Loading…
Reference in New Issue