mirror of https://github.com/coqui-ai/TTS.git
Add latent discriminator over unexpanded m_p
This commit is contained in:
parent
569decba64
commit
3165c55fae
|
@ -675,7 +675,6 @@ class VitsArgs(Coqpit):
|
|||
|
||||
use_noise_scale_predictor: bool = False
|
||||
use_latent_discriminator: bool = False
|
||||
use_avg_feature_on_latent_discriminator: bool = False
|
||||
|
||||
# Pitch predictor
|
||||
use_pitch_on_enc_input: bool = False
|
||||
|
@ -1493,6 +1492,11 @@ class Vits(BaseTTS):
|
|||
pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(m_p, x_lengths, pitch, attn.sum(3), g_dp)
|
||||
m_p = m_p + gt_avg_pitch_emb
|
||||
|
||||
z_p_avg = None
|
||||
if self.args.use_latent_discriminator:
|
||||
# average the z_p for the latent discriminator
|
||||
z_p_avg = average_over_durations(z_p, attn.sum(3).squeeze())
|
||||
|
||||
# expand prior
|
||||
m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
|
||||
logs_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
|
||||
|
@ -1553,6 +1557,8 @@ class Vits(BaseTTS):
|
|||
{
|
||||
"model_outputs": o,
|
||||
"alignments": attn.squeeze(1),
|
||||
"m_p_unexpanded": m_p,
|
||||
"z_p_avg": z_p_avg,
|
||||
"m_p": m_p_expanded,
|
||||
"logs_p": logs_p_expanded,
|
||||
"z": z,
|
||||
|
@ -1885,8 +1891,8 @@ class Vits(BaseTTS):
|
|||
scores_disc_fake, _, scores_disc_real, _, scores_disc_mp, _, scores_disc_zp, _ = self.disc(
|
||||
outputs["model_outputs"].detach(),
|
||||
outputs["waveform_seg"],
|
||||
outputs["m_p"].detach(),
|
||||
outputs["z_p"].detach(),
|
||||
outputs["m_p_unexpanded"].detach(),
|
||||
outputs["z_p_avg"].detach(),
|
||||
)
|
||||
|
||||
# compute loss
|
||||
|
@ -1933,8 +1939,8 @@ class Vits(BaseTTS):
|
|||
) = self.disc(
|
||||
self.model_outputs_cache["model_outputs"],
|
||||
self.model_outputs_cache["waveform_seg"],
|
||||
self.model_outputs_cache["m_p"],
|
||||
self.model_outputs_cache["z_p"].detach(),
|
||||
self.model_outputs_cache["m_p_unexpanded"],
|
||||
self.model_outputs_cache["z_p_avg"].detach(),
|
||||
)
|
||||
|
||||
# compute losses
|
||||
|
|
|
@ -53,6 +53,8 @@ config.model_args.pitch_embedding_dim = 2
|
|||
config.model_args.condition_dp_on_speaker = True
|
||||
|
||||
|
||||
config.model_args.use_latent_discriminator = True
|
||||
|
||||
config.save_json(config_path)
|
||||
# train the model for one epoch
|
||||
command_train = (
|
||||
|
|
Loading…
Reference in New Issue