Add latent discriminator over unexpanded m_p

This commit is contained in:
Edresson Casanova 2022-06-18 14:26:56 +00:00
parent 569decba64
commit 3165c55fae
2 changed files with 13 additions and 5 deletions

View File

@@ -675,7 +675,6 @@ class VitsArgs(Coqpit):
use_noise_scale_predictor: bool = False
use_latent_discriminator: bool = False
use_avg_feature_on_latent_discriminator: bool = False
# Pitch predictor
use_pitch_on_enc_input: bool = False
@@ -1493,6 +1492,11 @@ class Vits(BaseTTS):
pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(m_p, x_lengths, pitch, attn.sum(3), g_dp)
m_p = m_p + gt_avg_pitch_emb
z_p_avg = None
if self.args.use_latent_discriminator:
# average the z_p for the latent discriminator
z_p_avg = average_over_durations(z_p, attn.sum(3).squeeze())
# expand prior
m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
logs_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
@@ -1553,6 +1557,8 @@ class Vits(BaseTTS):
{
"model_outputs": o,
"alignments": attn.squeeze(1),
"m_p_unexpanded": m_p,
"z_p_avg": z_p_avg,
"m_p": m_p_expanded,
"logs_p": logs_p_expanded,
"z": z,
@@ -1885,8 +1891,8 @@ class Vits(BaseTTS):
scores_disc_fake, _, scores_disc_real, _, scores_disc_mp, _, scores_disc_zp, _ = self.disc(
outputs["model_outputs"].detach(),
outputs["waveform_seg"],
outputs["m_p"].detach(),
outputs["z_p"].detach(),
outputs["m_p_unexpanded"].detach(),
outputs["z_p_avg"].detach(),
)
# compute loss
@@ -1933,8 +1939,8 @@ class Vits(BaseTTS):
) = self.disc(
self.model_outputs_cache["model_outputs"],
self.model_outputs_cache["waveform_seg"],
self.model_outputs_cache["m_p"],
self.model_outputs_cache["z_p"].detach(),
self.model_outputs_cache["m_p_unexpanded"],
self.model_outputs_cache["z_p_avg"].detach(),
)
# compute losses

View File

@@ -53,6 +53,8 @@ config.model_args.pitch_embedding_dim = 2
config.model_args.condition_dp_on_speaker = True
config.model_args.use_latent_discriminator = True
config.save_json(config_path)
# train the model for one epoch
command_train = (