From 4e94b46d5e1833f32ea1a79128e1bd6139015544 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Thu, 2 Jun 2022 13:50:08 -0300
Subject: [PATCH] Add end2end VITS loss

---
 TTS/tts/datasets/formatters.py                 |  7 +-
 TTS/tts/layers/losses.py                       | 47 ++++++++++-
 TTS/tts/models/vits.py                         | 83 +++++++++++++++++--
 ...t_vits_speaker_emb_with_prosody_encoder.py  |  3 +
 4 files changed, 132 insertions(+), 8 deletions(-)

diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py
index 4398c960..28c3b956 100644
--- a/TTS/tts/datasets/formatters.py
+++ b/TTS/tts/datasets/formatters.py
@@ -329,7 +329,12 @@ def brspeech(root_path, meta_file, ignored_speakers=None):
         if isinstance(ignored_speakers, list):
             if speaker_id in ignored_speakers:
                 continue
-        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id})
+
+        if os.path.exists(wav_file):
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id})
+        else:
+            print(f" [!] wav file doesn't exist - {wav_file}")
+
     return items
 
diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index 00b00b77..2fb39a63 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -671,6 +671,7 @@ class VitsGeneratorLoss(nn.Module):
         loss_prosody_enc_emo_classifier=None,
         loss_text_enc_spk_rev_classifier=None,
         loss_text_enc_emo_classifier=None,
+        end2end_info=None,
     ):
         """
         Shapes:
@@ -736,7 +737,39 @@ class VitsGeneratorLoss(nn.Module):
             loss += kl_vae_loss
             return_dict["loss_kl_vae"] = kl_vae_loss
 
+        if end2end_info is not None:
+            # the feature loss is not computed here because it requires waveform segments of the same length
+            '''loss_feat_end2end = (
+                self.feature_loss(feats_real=end2end_info["feats_disc_real"], feats_generated=end2end_info["feats_disc_fake"]) * self.feat_loss_alpha
+            )
+            return_dict["loss_feat_end2end"] = loss_feat_end2end
+            loss += loss_feat_end2end'''
+
+            # generator (adversarial) loss
+            loss_gen_end2end = self.generator_loss(scores_fake=end2end_info["scores_disc_fake"])[0] * self.gen_loss_alpha
+            return_dict["loss_gen_end2end"] = loss_gen_end2end
+            loss += loss_gen_end2end
+
+            # if soft DTW is not used
+            if end2end_info["z_predicted"] is not None:
+                # KL loss using ground-truth durations
+                z = end2end_info["z"].float()
+                logs_q = end2end_info["logs_q"].float()
+                z_predicted = end2end_info["z_predicted"].float()
+                logs_p = end2end_info["logs_p"].float()
+                z_mask = end2end_info["z_mask"].float()
+
+                kl = logs_p - logs_q - 0.5
+                kl += 0.5 * ((z - z_predicted) ** 2) * torch.exp(-2.0 * logs_p)
+                kl = torch.sum(kl * z_mask)
+                loss_kl_end2end_gt_durations = kl / torch.sum(z_mask)
+                return_dict["loss_kl_end2end_gt_durations"] = loss_kl_end2end_gt_durations
+                loss += loss_kl_end2end_gt_durations
+            else:
+                pass
+                # TODO: implement soft DTW
+
         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
         return_dict["loss_kl"] = loss_kl
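Note (not part of the patch): a minimal standalone sketch of the masked KL term added above, assuming tensors shaped [B, C, T] and a [B, 1, T] mask; the helper name `kl_gt_durations` is made up for illustration.

    import torch

    def kl_gt_durations(z, logs_q, z_predicted, logs_p, z_mask):
        # same arithmetic as the loss above: per-element KL-style distance
        # between the posterior sample (z, logs_q) and the flow-inverted
        # prior (z_predicted, logs_p), averaged over valid frames only
        kl = logs_p - logs_q - 0.5
        kl += 0.5 * ((z - z_predicted) ** 2) * torch.exp(-2.0 * logs_p)
        return torch.sum(kl * z_mask) / torch.sum(z_mask)

    B, C, T = 2, 192, 100
    loss = kl_gt_durations(
        torch.randn(B, C, T), torch.randn(B, C, T),
        torch.randn(B, C, T), torch.randn(B, C, T), torch.ones(B, 1, T),
    )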
return_dict["loss_disc_end2end"] = loss_disc_end2end * self.disc_loss_alpha + return_dict["loss"] += return_dict["loss_disc_end2end"] + + for i, ldr in enumerate(loss_disc_real_end2end): + return_dict[f"loss_disc_end2end_real_{i}"] = ldr + + return return_dict diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 41808866..723ad2ab 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -559,6 +559,10 @@ class VitsArgs(Coqpit): use_prosody_conditional_flow_module: bool = False prosody_conditional_flow_module_on_decoder: bool = False + # end 2 end loss + use_end2end_loss: bool = False + use_soft_dtw: bool = False + detach_dp_input: bool = True use_language_embedding: bool = False embedded_language_dim: int = 4 @@ -1037,6 +1041,7 @@ class Vits(BaseTTS): return g def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb): + predicted_durations = None # find the alignment path attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) with torch.no_grad(): @@ -1059,6 +1064,15 @@ class Vits(BaseTTS): lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, ) loss_duration = loss_duration / torch.sum(x_mask) + if self.args.use_end2end_loss: + predicted_durations = self.duration_predictor( + x.detach() if self.args.detach_dp_input else x, + x_mask, + g=g.detach() if self.args.detach_dp_input and g is not None else g, + reverse=True, + noise_scale=self.inference_noise_scale_dp, + lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb + ) else: attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask log_durations = self.duration_predictor( @@ -1067,9 +1081,10 @@ class Vits(BaseTTS): g=g.detach() if self.args.detach_dp_input and g is not None else g, lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, ) + predicted_durations = log_durations loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask) outputs["loss_duration"] = loss_duration - return outputs, attn + return outputs, attn, predicted_durations def upsampling_z(self, z, slice_ids=None, y_lengths=None, y_mask=None): spec_segment_size = self.spec_segment_size @@ -1163,6 +1178,7 @@ class Vits(BaseTTS): # prosody embedding pros_emb = None + vae_outputs = None l_pros_speaker = None l_pros_emotion = None if self.args.use_prosody_encoder: @@ -1224,11 +1240,11 @@ class Vits(BaseTTS): else: g_dp = torch.cat([g_dp, pros_emb], dim=1) # [b, h1+h2, 1] - outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb) + outputs, attn, predicted_durations = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb) # expand prior - m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p]) - logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p]) + m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p]) + logs_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, logs_p]) # select a random feature segment for the waveform decoder z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True) @@ -1267,12 +1283,49 @@ class Vits(BaseTTS): else: gt_cons_emb, syn_cons_emb = None, None + end2end_dict = None + if self.args.use_end2end_loss: + # predicted_durations + w = torch.exp(predicted_durations) * x_mask * self.length_scale + w_ceil = torch.ceil(w) + y_lengths_end2end = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 
@@ -1163,6 +1178,7 @@ class Vits(BaseTTS):
 
         # prosody embedding
         pros_emb = None
+        vae_outputs = None
         l_pros_speaker = None
         l_pros_emotion = None
         if self.args.use_prosody_encoder:
@@ -1224,11 +1240,11 @@ class Vits(BaseTTS):
             else:
                 g_dp = torch.cat([g_dp, pros_emb], dim=1)  # [b, h1+h2, 1]
 
-        outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)
+        outputs, attn, predicted_durations = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)
 
         # expand prior
-        m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
-        logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
+        m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
+        logs_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
 
         # select a random feature segment for the waveform decoder
         z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
@@ -1267,12 +1283,49 @@ class Vits(BaseTTS):
         else:
             gt_cons_emb, syn_cons_emb = None, None
 
+        end2end_dict = None
+        if self.args.use_end2end_loss:
+            # expand the predicted durations to decoder-frame lengths
+            w = torch.exp(predicted_durations) * x_mask * self.length_scale
+            w_ceil = torch.ceil(w)
+            y_lengths_end2end = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
+            y_mask_end2end = sequence_mask(y_lengths_end2end, None).to(x_mask.dtype).unsqueeze(1)  # [B, 1, T_dec]
+
+            attn_mask = x_mask * y_mask_end2end.transpose(1, 2)  # [B, 1, T_enc] * [B, T_dec, 1]
+            attn_end2end = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))
+
+            m_p_end2end = torch.matmul(attn_end2end.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
+            logs_p_end2end = torch.matmul(attn_end2end.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
+
+            z_p_end2end = m_p_end2end * y_mask_end2end  # + torch.randn_like(m_p_end2end) * torch.exp(logs_p_end2end) * self.inference_noise_scale
+
+            # conditional module
+            if self.args.use_prosody_conditional_flow_module:
+                if self.args.prosody_conditional_flow_module_on_decoder:
+                    z_p_end2end = self.prosody_conditional_module(z_p_end2end, y_mask_end2end, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb, reverse=True)
+
+            z_end2end = self.flow(z_p_end2end, y_mask_end2end, g=g, reverse=True)
+
+            # interpolate z if needed
+            z_end2end, _, _, y_mask_end2end = self.upsampling_z(z_end2end, y_lengths=y_lengths_end2end, y_mask=y_mask_end2end)
+            # z_slice_end2end, spec_segment_size, slice_ids_end2end, _ = self.upsampling_z(z_slice_end2end, slice_ids=slice_ids_end2end)
+
+            # generate the whole waveform from z with the waveform decoder
+            o_end2end = self.waveform_decoder(z_end2end, g=g)
+            wav_seg_end2end = waveform
+
+            z_predicted_gt_durations = None
+            if not self.args.use_soft_dtw:
+                z_predicted_gt_durations = self.flow(m_p_expanded * y_mask, y_mask, g=g, reverse=True)
+
+            end2end_dict = {"logs_p_end2end": logs_p_end2end, "logs_p": logs_p_expanded, "logs_q": logs_q, "z_mask": y_mask, "z_mask_end2end": y_mask_end2end, "z": z, "z_predicted_end2end": z_end2end, "z_predicted": z_predicted_gt_durations, "model_outputs": o_end2end, "waveform_seg": wav_seg_end2end}
+
         outputs.update(
             {
                 "model_outputs": o,
                 "alignments": attn.squeeze(1),
-                "m_p": m_p,
-                "logs_p": logs_p,
+                "m_p": m_p_expanded,
+                "logs_p": logs_p_expanded,
                 "z": z,
                 "z_p": z_p,
                 "m_q": m_q,
@@ -1286,6 +1339,7 @@ class Vits(BaseTTS):
                 "loss_prosody_enc_emo_classifier": l_pros_emotion,
                 "loss_text_enc_spk_rev_classifier": l_text_speaker,
                 "loss_text_enc_emo_classifier": l_text_emotion,
+                "end2end_info": end2end_dict,
             }
         )
         return outputs
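Note (not part of the patch): the train_step changes below score the full-length end2end waveform twice — detached in the discriminator pass, attached in the generator pass. A hypothetical sketch of that wiring; `disc` and its 4-tuple return order are stand-ins mirroring the calls below, not the real Vits API.

    def disc_step(disc, end2end_info):
        # detached: only the discriminator receives gradients
        scores_fake, _, scores_real, _ = disc(
            end2end_info["model_outputs"].detach(), end2end_info["waveform_seg"]
        )
        return {"scores_disc_real": scores_real, "scores_disc_fake": scores_fake}

    def gen_step(disc, end2end_info):
        # attached: gradients flow back into the generator
        scores_fake, feats_fake, _, feats_real = disc(
            end2end_info["model_outputs"], end2end_info["waveform_seg"]
        )
        end2end_info["scores_disc_fake"] = scores_fake
        end2end_info["feats_disc_fake"] = feats_fake
        end2end_info["feats_disc_real"] = feats_real
        return end2end_info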
self.model_outputs_cache["end2end_info"]["feats_disc_real"] = feats_disc_real_end2end + # compute losses with autocast(enabled=False): # use float32 for the criterion loss_dict = criterion[optimizer_idx]( @@ -1643,6 +1713,7 @@ class Vits(BaseTTS): loss_prosody_enc_emo_classifier=self.model_outputs_cache["loss_prosody_enc_emo_classifier"], loss_text_enc_spk_rev_classifier=self.model_outputs_cache["loss_text_enc_spk_rev_classifier"], loss_text_enc_emo_classifier=self.model_outputs_cache["loss_text_enc_emo_classifier"], + end2end_info=self.model_outputs_cache["end2end_info"], ) return self.model_outputs_cache, loss_dict diff --git a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py index 4fe3a8b1..6fa4a536 100644 --- a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py +++ b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py @@ -52,6 +52,9 @@ config.model_args.use_prosody_encoder_z_p_input = True config.model_args.prosody_encoder_type = "vae" config.model_args.detach_prosody_enc_input = True +# enable end2end loss +config.model_args.use_end2end_loss = True + config.mixed_precision = False config.save_json(config_path)