mirror of https://github.com/coqui-ai/TTS.git
Remove VITS End2End loss
This commit is contained in:
parent
ae55bdae6c
commit
a1d0088087
|
@ -685,7 +685,6 @@ class VitsGeneratorLoss(nn.Module):
|
|||
scores_disc_mp=None,
|
||||
feats_disc_mp=None,
|
||||
feats_disc_zp=None,
|
||||
end2end_info=None,
|
||||
):
|
||||
"""
|
||||
Shapes:
|
||||
|
@ -762,32 +761,6 @@ class VitsGeneratorLoss(nn.Module):
|
|||
loss += kl_vae_loss
|
||||
return_dict["loss_kl_vae"] = kl_vae_loss
|
||||
|
||||
if end2end_info is not None:
|
||||
|
||||
# gen loss
|
||||
loss_gen_end2end = self.generator_loss(scores_fake=end2end_info["scores_disc_fake"])[0] * self.gen_loss_alpha
|
||||
return_dict["loss_gen_end2end"] = loss_gen_end2end
|
||||
loss += loss_gen_end2end
|
||||
|
||||
# if do not uses soft dtw
|
||||
if end2end_info["z_predicted"] is not None:
|
||||
# loss KL using GT durations
|
||||
z = end2end_info["z"].float()
|
||||
logs_q = end2end_info["logs_q"].float()
|
||||
z_predicted = end2end_info["z_predicted"].float()
|
||||
logs_p = end2end_info["logs_p"].float()
|
||||
z_mask = end2end_info["z_mask"].float()
|
||||
|
||||
kl = logs_p - logs_q - 0.5
|
||||
kl += 0.5 * ((z - z_predicted) ** 2) * torch.exp(-2.0 * logs_p)
|
||||
kl = torch.sum(kl * z_mask)
|
||||
loss_kl_end2end_gt_durations = kl / torch.sum(z_mask)
|
||||
return_dict["loss_kl_end2end_gt_durations"] = loss_kl_end2end_gt_durations
|
||||
loss += loss_kl_end2end_gt_durations
|
||||
else:
|
||||
pass
|
||||
# ToDo: implement soft dtw
|
||||
|
||||
# pass losses to the dict
|
||||
return_dict["loss_gen"] = loss_gen
|
||||
return_dict["loss_kl"] = loss_kl
|
||||
|
@ -822,7 +795,7 @@ class VitsDiscriminatorLoss(nn.Module):
|
|||
fake_losses.append(fake_loss.item())
|
||||
return loss, real_losses, fake_losses
|
||||
|
||||
def forward(self, scores_disc_real, scores_disc_fake, scores_disc_zp=None, scores_disc_mp=None, end2end_info=None):
|
||||
def forward(self, scores_disc_real, scores_disc_fake, scores_disc_zp=None, scores_disc_mp=None):
|
||||
return_dict = {}
|
||||
return_dict["loss"] = 0.0
|
||||
loss_disc, loss_disc_real, _ = self.discriminator_loss(
|
||||
|
@ -844,17 +817,6 @@ class VitsDiscriminatorLoss(nn.Module):
|
|||
return_dict["loss_disc_latent"] = loss_disc_latent * self.disc_latent_loss_alpha
|
||||
return_dict["loss"] += return_dict["loss_disc_latent"]
|
||||
|
||||
if end2end_info is not None:
|
||||
loss_disc_end2end, loss_disc_real_end2end, _ = self.discriminator_loss(
|
||||
scores_real=end2end_info["scores_disc_real"], scores_fake=end2end_info["scores_disc_fake"],
|
||||
)
|
||||
return_dict["loss_disc_end2end"] = loss_disc_end2end * self.disc_loss_alpha
|
||||
return_dict["loss"] += return_dict["loss_disc_end2end"]
|
||||
|
||||
for i, ldr in enumerate(loss_disc_real_end2end):
|
||||
return_dict[f"loss_disc_end2end_real_{i}"] = ldr
|
||||
|
||||
|
||||
return return_dict
|
||||
|
||||
|
||||
|
|
|
@ -562,10 +562,6 @@ class VitsArgs(Coqpit):
|
|||
use_prosody_conditional_flow_module: bool = False
|
||||
prosody_conditional_flow_module_on_decoder: bool = False
|
||||
|
||||
# end 2 end loss
|
||||
use_end2end_loss: bool = False
|
||||
use_soft_dtw: bool = False
|
||||
|
||||
use_latent_discriminator: bool = False
|
||||
|
||||
detach_dp_input: bool = True
|
||||
|
@ -1069,7 +1065,6 @@ class Vits(BaseTTS):
|
|||
return g
|
||||
|
||||
def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
|
||||
predicted_durations = None
|
||||
# find the alignment path
|
||||
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||
with torch.no_grad():
|
||||
|
@ -1092,15 +1087,6 @@ class Vits(BaseTTS):
|
|||
lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
|
||||
)
|
||||
loss_duration = loss_duration / torch.sum(x_mask)
|
||||
if self.args.use_end2end_loss:
|
||||
predicted_durations = self.duration_predictor(
|
||||
x.detach() if self.args.detach_dp_input else x,
|
||||
x_mask,
|
||||
g=g.detach() if self.args.detach_dp_input and g is not None else g,
|
||||
reverse=True,
|
||||
noise_scale=self.inference_noise_scale_dp,
|
||||
lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb
|
||||
)
|
||||
else:
|
||||
attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
|
||||
log_durations = self.duration_predictor(
|
||||
|
@ -1109,10 +1095,9 @@ class Vits(BaseTTS):
|
|||
g=g.detach() if self.args.detach_dp_input and g is not None else g,
|
||||
lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
|
||||
)
|
||||
predicted_durations = log_durations
|
||||
loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
|
||||
outputs["loss_duration"] = loss_duration
|
||||
return outputs, attn, predicted_durations
|
||||
return outputs, attn
|
||||
|
||||
def upsampling_z(self, z, slice_ids=None, y_lengths=None, y_mask=None):
|
||||
spec_segment_size = self.spec_segment_size
|
||||
|
@ -1268,7 +1253,7 @@ class Vits(BaseTTS):
|
|||
else:
|
||||
g_dp = torch.cat([g_dp, pros_emb], dim=1) # [b, h1+h2, 1]
|
||||
|
||||
outputs, attn, predicted_durations = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)
|
||||
outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)
|
||||
|
||||
# expand prior
|
||||
m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
|
||||
|
@ -1322,43 +1307,6 @@ class Vits(BaseTTS):
|
|||
else:
|
||||
gt_cons_emb, syn_cons_emb = None, None
|
||||
|
||||
end2end_dict = None
|
||||
if self.args.use_end2end_loss:
|
||||
# predicted_durations
|
||||
w = torch.exp(predicted_durations) * x_mask * self.length_scale
|
||||
w_ceil = torch.ceil(w)
|
||||
y_lengths_end2end = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
|
||||
y_mask_end2end = sequence_mask(y_lengths_end2end, None).to(x_mask.dtype).unsqueeze(1) # [B, 1, T_dec]
|
||||
|
||||
attn_mask = x_mask * y_mask_end2end.transpose(1, 2) # [B, 1, T_enc] * [B, T_dec, 1]
|
||||
attn_end2end = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))
|
||||
|
||||
m_p_end2end = torch.matmul(attn_end2end.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
|
||||
logs_p_end2end = torch.matmul(attn_end2end.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
z_p_end2end = m_p_end2end * y_mask_end2end #+ torch.randn_like(m_p_end2end) * torch.exp(logs_p_end2end) * self.inference_noise_scale
|
||||
|
||||
# conditional module
|
||||
if self.args.use_prosody_conditional_flow_module:
|
||||
if self.args.prosody_conditional_flow_module_on_decoder:
|
||||
z_p_end2end = self.prosody_conditional_module(z_p_end2end, y_mask_end2end, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb, reverse=True)
|
||||
|
||||
z_end2end = self.flow(z_p_end2end, y_mask_end2end, g=g, reverse=True)
|
||||
|
||||
# interpolate z if needed
|
||||
z_end2end, _, _, y_mask_end2end = self.upsampling_z(z, y_lengths=y_lengths_end2end, y_mask=y_mask_end2end)
|
||||
# z_slice_end2end, spec_segment_size, slice_ids_end2end, _ = self.upsampling_z(z_slice_end2end, slice_ids=slice_ids_end2end)
|
||||
|
||||
# generate all z using the vocoder
|
||||
o_end2end = self.waveform_decoder(z_end2end, g=g)
|
||||
wav_seg_end2end = waveform
|
||||
|
||||
z_predicted_gt_durations = None
|
||||
if not self.args.use_soft_dtw:
|
||||
z_predicted_gt_durations = self.flow(m_p_expanded * y_mask, y_mask, g=g, reverse=True)
|
||||
|
||||
end2end_dict = {"logs_p_end2end": logs_p_end2end, "logs_p": logs_p_expanded, "logs_q": logs_q, "z_mask": y_mask, "z_mask_end2end": y_mask_end2end, "z": z, "z_predicted_end2end": z_end2end, "z_predicted": z_predicted_gt_durations, "model_outputs": o_end2end, "waveform_seg": wav_seg_end2end}
|
||||
|
||||
outputs.update(
|
||||
{
|
||||
"model_outputs": o,
|
||||
|
@ -1377,8 +1325,7 @@ class Vits(BaseTTS):
|
|||
"loss_prosody_enc_spk_rev_classifier": l_pros_speaker,
|
||||
"loss_prosody_enc_emo_classifier": l_pros_emotion,
|
||||
"loss_text_enc_spk_rev_classifier": l_text_speaker,
|
||||
"loss_text_enc_emo_classifier": l_text_emotion,
|
||||
"end2end_info": end2end_dict,
|
||||
"loss_text_enc_emo_classifier": l_text_emotion
|
||||
}
|
||||
)
|
||||
return outputs
|
||||
|
@ -1686,21 +1633,13 @@ class Vits(BaseTTS):
|
|||
outputs["model_outputs"].detach(), outputs["waveform_seg"], outputs["m_p"].detach(), outputs["z_p"].detach()
|
||||
)
|
||||
|
||||
end2end_info = None
|
||||
if self.args.use_end2end_loss:
|
||||
scores_disc_fake_end2end, _, scores_disc_real_end2end, _, _, _, _, _ = self.disc(
|
||||
outputs["end2end_info"]["model_outputs"].detach(), self.model_outputs_cache["end2end_info"]["waveform_seg"]
|
||||
)
|
||||
end2end_info = {"scores_disc_real": scores_disc_real_end2end, "scores_disc_fake": scores_disc_fake_end2end}
|
||||
|
||||
# compute loss
|
||||
with autocast(enabled=False): # use float32 for the criterion
|
||||
loss_dict = criterion[optimizer_idx](
|
||||
scores_disc_real,
|
||||
scores_disc_fake,
|
||||
scores_disc_zp,
|
||||
scores_disc_mp,
|
||||
end2end_info=end2end_info,
|
||||
scores_disc_mp
|
||||
)
|
||||
return outputs, loss_dict
|
||||
|
||||
|
@ -1735,14 +1674,6 @@ class Vits(BaseTTS):
|
|||
self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"], self.model_outputs_cache["m_p"], self.model_outputs_cache["z_p"].detach()
|
||||
)
|
||||
|
||||
if self.args.use_end2end_loss:
|
||||
scores_disc_fake_end2end, feats_disc_fake_end2end, _, feats_disc_real_end2end, _, _, _, _, _ = self.disc(
|
||||
self.model_outputs_cache["end2end_info"]["model_outputs"], self.model_outputs_cache["end2end_info"]["waveform_seg"]
|
||||
)
|
||||
self.model_outputs_cache["end2end_info"]["scores_disc_fake"] = scores_disc_fake_end2end
|
||||
self.model_outputs_cache["end2end_info"]["feats_disc_fake"] = feats_disc_fake_end2end
|
||||
self.model_outputs_cache["end2end_info"]["feats_disc_real"] = feats_disc_real_end2end
|
||||
|
||||
# compute losses
|
||||
with autocast(enabled=False): # use float32 for the criterion
|
||||
loss_dict = criterion[optimizer_idx](
|
||||
|
@ -1768,8 +1699,7 @@ class Vits(BaseTTS):
|
|||
loss_text_enc_emo_classifier=self.model_outputs_cache["loss_text_enc_emo_classifier"],
|
||||
scores_disc_mp=scores_disc_mp,
|
||||
feats_disc_mp=feats_disc_mp,
|
||||
feats_disc_zp=feats_disc_zp,
|
||||
end2end_info=self.model_outputs_cache["end2end_info"],
|
||||
feats_disc_zp=feats_disc_zp
|
||||
)
|
||||
|
||||
return self.model_outputs_cache, loss_dict
|
||||
|
|
Loading…
Reference in New Issue