mirror of https://github.com/coqui-ai/TTS.git
Add end2end VITS loss
This commit is contained in:
parent ec8c8dc5a2
commit 4e94b46d5e
@@ -329,7 +329,12 @@ def brspeech(root_path, meta_file, ignored_speakers=None):
         if isinstance(ignored_speakers, list):
             if speaker_id in ignored_speakers:
                 continue
-        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id})
+        if os.path.exists(wav_file):
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id})
+        else:
+            print(f" [!] wav files don't exist - {wav_file}")
     return items
@@ -671,6 +671,7 @@ class VitsGeneratorLoss(nn.Module):
         loss_prosody_enc_emo_classifier=None,
         loss_text_enc_spk_rev_classifier=None,
         loss_text_enc_emo_classifier=None,
+        end2end_info=None,
     ):
         """
         Shapes:
@@ -737,6 +738,38 @@ class VitsGeneratorLoss(nn.Module):
             loss += kl_vae_loss
             return_dict["loss_kl_vae"] = kl_vae_loss
 
+        if end2end_info is not None:
+            # do not compute the feature loss here: it needs waveform segments of the same length
+            '''loss_feat_end2end = (
+                self.feature_loss(feats_real=end2end_info["feats_disc_real"], feats_generated=end2end_info["feats_disc_fake"]) * self.feat_loss_alpha
+            )
+            return_dict["loss_feat_end2end"] = loss_feat_end2end
+            loss += loss_feat_end2end'''
+
+            # gen loss
+            loss_gen_end2end = self.generator_loss(scores_fake=end2end_info["scores_disc_fake"])[0] * self.gen_loss_alpha
+            return_dict["loss_gen_end2end"] = loss_gen_end2end
+            loss += loss_gen_end2end
+
+            # if soft DTW is not used, compute the KL loss with ground-truth durations
+            if end2end_info["z_predicted"] is not None:
+                z = end2end_info["z"].float()
+                logs_q = end2end_info["logs_q"].float()
+                z_predicted = end2end_info["z_predicted"].float()
+                logs_p = end2end_info["logs_p"].float()
+                z_mask = end2end_info["z_mask"].float()
+
+                kl = logs_p - logs_q - 0.5
+                kl += 0.5 * ((z - z_predicted) ** 2) * torch.exp(-2.0 * logs_p)
+                kl = torch.sum(kl * z_mask)
+                loss_kl_end2end_gt_durations = kl / torch.sum(z_mask)
+                return_dict["loss_kl_end2end_gt_durations"] = loss_kl_end2end_gt_durations
+                loss += loss_kl_end2end_gt_durations
+            else:
+                pass
+                # TODO: implement soft DTW
+
         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
         return_dict["loss_kl"] = loss_kl
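
For reference, the new `loss_kl_end2end_gt_durations` term has the same form as the original VITS KL, evaluated between the posterior sample `z` and the flow-inverted prediction `z_predicted`. A minimal, self-contained sketch with toy tensors standing in for the model's outputs (the `[B, C, T]` shapes are an assumption following the VITS convention):

```python
import torch

def kl_end2end_gt_durations(z, z_predicted, logs_q, logs_p, z_mask):
    # same arithmetic as the hunk above: masked KL between diagonal
    # Gaussians, normalized by the number of valid frames
    kl = logs_p - logs_q - 0.5
    kl += 0.5 * ((z - z_predicted) ** 2) * torch.exp(-2.0 * logs_p)
    return torch.sum(kl * z_mask) / torch.sum(z_mask)

# toy shapes: [batch=2, channels=8, frames=10]; the mask broadcasts over channels
z, z_pred = torch.randn(2, 8, 10), torch.randn(2, 8, 10)
logs = torch.zeros(2, 8, 10)
mask = torch.ones(2, 1, 10)
print(kl_end2end_gt_durations(z, z_pred, logs, logs, mask))
```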
@@ -767,7 +800,7 @@ class VitsDiscriminatorLoss(nn.Module):
             fake_losses.append(fake_loss.item())
         return loss, real_losses, fake_losses
 
-    def forward(self, scores_disc_real, scores_disc_fake):
+    def forward(self, scores_disc_real, scores_disc_fake, end2end_info=None):
         loss = 0.0
         return_dict = {}
         loss_disc, loss_disc_real, _ = self.discriminator_loss(
@@ -779,6 +812,18 @@ class VitsDiscriminatorLoss(nn.Module):
         for i, ldr in enumerate(loss_disc_real):
             return_dict[f"loss_disc_real_{i}"] = ldr
 
+        if end2end_info is not None:
+            loss_disc_end2end, loss_disc_real_end2end, _ = self.discriminator_loss(
+                scores_real=end2end_info["scores_disc_real"],
+                scores_fake=end2end_info["scores_disc_fake"],
+            )
+            return_dict["loss_disc_end2end"] = loss_disc_end2end * self.disc_loss_alpha
+            return_dict["loss"] += return_dict["loss_disc_end2end"]
+
+            for i, ldr in enumerate(loss_disc_real_end2end):
+                return_dict[f"loss_disc_end2end_real_{i}"] = ldr
+
         return return_dict
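
`discriminator_loss` is reused unchanged for the end2end scores. As a sketch of what it computes, here is an LSGAN-style reimplementation over per-sub-discriminator score lists (illustrative only, assuming the standard VITS objective rather than quoting the class's actual method body):

```python
import torch

def lsgan_discriminator_loss(scores_real, scores_fake):
    # push real scores toward 1 and fake scores toward 0,
    # summing the per-sub-discriminator losses
    loss, real_losses, fake_losses = 0.0, [], []
    for dr, dg in zip(scores_real, scores_fake):
        real_loss = torch.mean((1.0 - dr.float()) ** 2)
        fake_loss = torch.mean(dg.float() ** 2)
        loss = loss + real_loss + fake_loss
        real_losses.append(real_loss.item())
        fake_losses.append(fake_loss.item())
    return loss, real_losses, fake_losses

# toy usage: two sub-discriminators with arbitrary score shapes
real = [torch.rand(2, 1, 50), torch.rand(2, 1, 25)]
fake = [torch.rand(2, 1, 50), torch.rand(2, 1, 25)]
loss, real_l, fake_l = lsgan_discriminator_loss(real, fake)
```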
@@ -559,6 +559,10 @@ class VitsArgs(Coqpit):
     use_prosody_conditional_flow_module: bool = False
     prosody_conditional_flow_module_on_decoder: bool = False
 
+    # end-to-end loss
+    use_end2end_loss: bool = False
+    use_soft_dtw: bool = False
+
     detach_dp_input: bool = True
     use_language_embedding: bool = False
     embedded_language_dim: int = 4
@@ -1037,6 +1041,7 @@ class Vits(BaseTTS):
         return g
 
     def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
+        predicted_durations = None
         # find the alignment path
         attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
         with torch.no_grad():
@@ -1059,6 +1064,15 @@ class Vits(BaseTTS):
                 lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
             )
             loss_duration = loss_duration / torch.sum(x_mask)
+            if self.args.use_end2end_loss:
+                predicted_durations = self.duration_predictor(
+                    x.detach() if self.args.detach_dp_input else x,
+                    x_mask,
+                    g=g.detach() if self.args.detach_dp_input and g is not None else g,
+                    reverse=True,
+                    noise_scale=self.inference_noise_scale_dp,
+                    lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
+                )
         else:
             attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
             log_durations = self.duration_predictor(
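
With `reverse=True` the stochastic duration predictor returns sampled log-durations instead of a loss. The forward pass (see the `-1267` hunk below) later converts them into integer frame counts; a toy sketch of that conversion, with `length_scale = 1.0` assumed:

```python
import torch

log_durations = torch.randn(2, 1, 5)  # [B, 1, T_enc], predictor output in reverse mode
x_mask = torch.ones(2, 1, 5)
w = torch.exp(log_durations) * x_mask * 1.0                       # length_scale assumed 1.0
w_ceil = torch.ceil(w)                                            # integer frames per token
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()  # at least one frame
print(y_lengths)                                                  # predicted decoder lengths
```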
@@ -1067,9 +1081,10 @@ class Vits(BaseTTS):
                 g=g.detach() if self.args.detach_dp_input and g is not None else g,
                 lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
             )
+            predicted_durations = log_durations
             loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
         outputs["loss_duration"] = loss_duration
-        return outputs, attn
+        return outputs, attn, predicted_durations
 
     def upsampling_z(self, z, slice_ids=None, y_lengths=None, y_mask=None):
         spec_segment_size = self.spec_segment_size
@@ -1163,6 +1178,7 @@ class Vits(BaseTTS):
 
         # prosody embedding
         pros_emb = None
+        vae_outputs = None
         l_pros_speaker = None
         l_pros_emotion = None
         if self.args.use_prosody_encoder:
@@ -1224,11 +1240,11 @@ class Vits(BaseTTS):
             else:
                 g_dp = torch.cat([g_dp, pros_emb], dim=1)  # [b, h1+h2, 1]
 
-        outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)
+        outputs, attn, predicted_durations = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)
 
         # expand prior
-        m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
-        logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
+        m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
+        logs_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
 
         # select a random feature segment for the waveform decoder
         z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
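
The rename only touches the variable names; the einsum expansion itself is unchanged. A quick shape check with toy tensors (shapes assumed from the VITS convention: `attn` is `[B, 1, T_enc, T_dec]`, `m_p` is `[B, C, T_enc]`):

```python
import torch

attn = torch.zeros(2, 1, 5, 12)  # [B, 1, T_enc, T_dec] hard alignment
m_p = torch.randn(2, 8, 5)       # [B, C, T_enc] token-level prior mean
m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
print(m_p_expanded.shape)        # torch.Size([2, 8, 12]) -> frame-level prior
```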
@@ -1267,12 +1283,49 @@ class Vits(BaseTTS):
         else:
             gt_cons_emb, syn_cons_emb = None, None
 
+        end2end_dict = None
+        if self.args.use_end2end_loss:
+            # turn the predicted log-durations into frame counts
+            w = torch.exp(predicted_durations) * x_mask * self.length_scale
+            w_ceil = torch.ceil(w)
+            y_lengths_end2end = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
+            y_mask_end2end = sequence_mask(y_lengths_end2end, None).to(x_mask.dtype).unsqueeze(1)  # [B, 1, T_dec]
+
+            attn_mask = x_mask * y_mask_end2end.transpose(1, 2)  # [B, 1, T_enc] * [B, T_dec, 1]
+            attn_end2end = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))
+
+            m_p_end2end = torch.matmul(attn_end2end.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
+            logs_p_end2end = torch.matmul(attn_end2end.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
+
+            z_p_end2end = m_p_end2end * y_mask_end2end  # + torch.randn_like(m_p_end2end) * torch.exp(logs_p_end2end) * self.inference_noise_scale
+
+            # conditional module
+            if self.args.use_prosody_conditional_flow_module:
+                if self.args.prosody_conditional_flow_module_on_decoder:
+                    z_p_end2end = self.prosody_conditional_module(z_p_end2end, y_mask_end2end, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb, reverse=True)
+
+            z_end2end = self.flow(z_p_end2end, y_mask_end2end, g=g, reverse=True)
+
+            # interpolate z if needed
+            z_end2end, _, _, y_mask_end2end = self.upsampling_z(z_end2end, y_lengths=y_lengths_end2end, y_mask=y_mask_end2end)
+            # z_slice_end2end, spec_segment_size, slice_ids_end2end, _ = self.upsampling_z(z_slice_end2end, slice_ids=slice_ids_end2end)
+
+            # generate the full waveform from z with the vocoder
+            o_end2end = self.waveform_decoder(z_end2end, g=g)
+            wav_seg_end2end = waveform
+
+            z_predicted_gt_durations = None
+            if not self.args.use_soft_dtw:
+                z_predicted_gt_durations = self.flow(m_p_expanded * y_mask, y_mask, g=g, reverse=True)
+
+            end2end_dict = {
+                "logs_p_end2end": logs_p_end2end,
+                "logs_p": logs_p_expanded,
+                "logs_q": logs_q,
+                "z_mask": y_mask,
+                "z_mask_end2end": y_mask_end2end,
+                "z": z,
+                "z_predicted_end2end": z_end2end,
+                "z_predicted": z_predicted_gt_durations,
+                "model_outputs": o_end2end,
+                "waveform_seg": wav_seg_end2end,
+            }
+
         outputs.update(
             {
                 "model_outputs": o,
                 "alignments": attn.squeeze(1),
-                "m_p": m_p,
-                "logs_p": logs_p,
+                "m_p": m_p_expanded,
+                "logs_p": logs_p_expanded,
                 "z": z,
                 "z_p": z_p,
                 "m_q": m_q,
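
`generate_path` (a Coqui TTS helper) turns the ceiled durations into the hard monotonic alignment that the matmuls above use to expand the prior. A hypothetical equivalent for illustration, not the library's implementation:

```python
import torch

def monotonic_path(durations, mask):
    # durations: [B, T_enc] integer frame counts; mask: [B, T_enc, T_dec]
    _, _, t_dec = mask.shape
    cum = torch.cumsum(durations, dim=1)  # running frame totals per token
    path = (torch.arange(t_dec).view(1, 1, t_dec) < cum.unsqueeze(-1)).float()
    # subtract the previous token's coverage so each token owns only its frames
    path[:, 1:] = path[:, 1:] - path[:, :-1]
    return path * mask

durations = torch.tensor([[2, 1, 3]])  # 3 tokens -> 2 + 1 + 3 = 6 frames
print(monotonic_path(durations, torch.ones(1, 3, 6)))
```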
@@ -1286,6 +1339,7 @@ class Vits(BaseTTS):
                 "loss_prosody_enc_emo_classifier": l_pros_emotion,
                 "loss_text_enc_spk_rev_classifier": l_text_speaker,
                 "loss_text_enc_emo_classifier": l_text_emotion,
+                "end2end_info": end2end_dict,
             }
         )
         return outputs
@@ -1581,11 +1635,19 @@ class Vits(BaseTTS):
                 outputs["model_outputs"].detach(), outputs["waveform_seg"]
             )
 
+            end2end_info = None
+            if self.args.use_end2end_loss:
+                scores_disc_fake_end2end, _, scores_disc_real_end2end, _ = self.disc(
+                    outputs["end2end_info"]["model_outputs"].detach(), self.model_outputs_cache["end2end_info"]["waveform_seg"]
+                )
+                end2end_info = {"scores_disc_real": scores_disc_real_end2end, "scores_disc_fake": scores_disc_fake_end2end}
+
             # compute loss
             with autocast(enabled=False):  # use float32 for the criterion
                 loss_dict = criterion[optimizer_idx](
                     scores_disc_real,
                     scores_disc_fake,
+                    end2end_info=end2end_info,
                 )
             return outputs, loss_dict
@@ -1620,6 +1682,14 @@ class Vits(BaseTTS):
                 self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"]
             )
 
+            if self.args.use_end2end_loss:
+                scores_disc_fake_end2end, feats_disc_fake_end2end, _, feats_disc_real_end2end = self.disc(
+                    self.model_outputs_cache["end2end_info"]["model_outputs"], self.model_outputs_cache["end2end_info"]["waveform_seg"]
+                )
+                self.model_outputs_cache["end2end_info"]["scores_disc_fake"] = scores_disc_fake_end2end
+                self.model_outputs_cache["end2end_info"]["feats_disc_fake"] = feats_disc_fake_end2end
+                self.model_outputs_cache["end2end_info"]["feats_disc_real"] = feats_disc_real_end2end
+
             # compute losses
             with autocast(enabled=False):  # use float32 for the criterion
                 loss_dict = criterion[optimizer_idx](
@@ -1643,6 +1713,7 @@ class Vits(BaseTTS):
                     loss_prosody_enc_emo_classifier=self.model_outputs_cache["loss_prosody_enc_emo_classifier"],
                     loss_text_enc_spk_rev_classifier=self.model_outputs_cache["loss_text_enc_spk_rev_classifier"],
                     loss_text_enc_emo_classifier=self.model_outputs_cache["loss_text_enc_emo_classifier"],
+                    end2end_info=self.model_outputs_cache["end2end_info"],
                 )
 
             return self.model_outputs_cache, loss_dict
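
Taken together, `end2end_info` acts as a small contract between `forward()`, the two optimizer steps, and the criteria. A sketch of its keys as assembled from the hunks above (names are the commit's own; values shown as placeholders):

```python
end2end_info = {
    # produced in forward() from the predicted durations
    "model_outputs": None,        # o_end2end: waveform decoded from the predicted z
    "waveform_seg": None,         # ground-truth waveform reference
    "z": None,                    # posterior sample
    "logs_q": None,               # posterior log-variance
    "z_predicted": None,          # flow-inverted prior under GT durations (no soft DTW)
    "logs_p": None,               # expanded prior log-variance
    "z_mask": None,               # frame mask for the GT-length branch
    "logs_p_end2end": None,
    "z_mask_end2end": None,
    "z_predicted_end2end": None,
    # attached during the discriminator / generator train steps
    "scores_disc_fake": None,
    "scores_disc_real": None,
    "feats_disc_fake": None,      # cached even though the feature loss is disabled
    "feats_disc_real": None,
}
```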
@@ -52,6 +52,9 @@ config.model_args.use_prosody_encoder_z_p_input = True
 config.model_args.prosody_encoder_type = "vae"
 config.model_args.detach_prosody_enc_input = True
 
+# enable the end2end loss
+config.model_args.use_end2end_loss = True
+
 config.mixed_precision = False
 
 config.save_json(config_path)
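
For completeness, the full flag set this commit introduces, shown on a fresh `VitsConfig` (a usage sketch assuming a standard Coqui TTS checkout; the recipe above only flips `use_end2end_loss`):

```python
from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig()
config.model_args.use_end2end_loss = True  # adds the gen/disc/KL end2end terms
config.model_args.use_soft_dtw = False     # the soft-DTW branch is still a TODO
config.mixed_precision = False             # the recipe disables AMP with this loss
```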