From 5bd59a60235a47bbc1b811f8266deffde5fcbefa Mon Sep 17 00:00:00 2001
From: Edresson Casanova <edresson1@gmail.com>
Date: Mon, 6 Jun 2022 15:10:00 -0300
Subject: [PATCH] Remove VITS End2End loss

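Drop the experimental end-to-end loss path: the use_end2end_loss and
use_soft_dtw options in VitsArgs, the reverse duration-predictor pass and
third return value of forward_mas(), the end2end_info outputs assembled in
Vits.forward(), the extra discriminator passes in train_step(), and the
matching loss terms in VitsGeneratorLoss and VitsDiscriminatorLoss.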
---
 TTS/tts/layers/losses.py | 40 +-------------------
 TTS/tts/models/vits.py   | 80 +++-------------------------------------
 2 files changed, 6 insertions(+), 114 deletions(-)
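Reviewer note (kept out of the commit message): after this patch,
forward_mas() returns only (outputs, attn) and the generator/discriminator
criteria no longer accept an end2end_info argument, so any out-of-tree caller
that unpacked a third value or passed that keyword has to be updated. For
reference, the KL-with-ground-truth-durations term dropped from
VitsGeneratorLoss is reproduced below as a standalone sketch; the function
name and the toy shapes under __main__ are illustrative only, while the math
mirrors the removed block.

    import torch

    def kl_end2end_gt_durations(z, logs_q, z_predicted, logs_p, z_mask):
        # KL between the posterior (z, logs_q) and the prior predicted from
        # ground-truth durations (z_predicted, logs_p), masked by z_mask and
        # normalized by the number of valid frames, as in the removed block.
        kl = logs_p - logs_q - 0.5
        kl += 0.5 * ((z - z_predicted) ** 2) * torch.exp(-2.0 * logs_p)
        kl = torch.sum(kl * z_mask)
        return kl / torch.sum(z_mask)

    if __name__ == "__main__":
        # toy shapes: [batch, channels, frames]
        b, c, t = 2, 192, 50
        args = [torch.randn(b, c, t) for _ in range(4)]
        print(kl_end2end_gt_durations(*args, torch.ones(b, 1, t)))
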

diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index 89a11139..0d8fbe0b 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -685,7 +685,6 @@ class VitsGeneratorLoss(nn.Module):
         scores_disc_mp=None,
         feats_disc_mp=None,
         feats_disc_zp=None,
-        end2end_info=None,
     ):
         """
         Shapes:
@@ -762,32 +761,6 @@ class VitsGeneratorLoss(nn.Module):
             loss += kl_vae_loss
             return_dict["loss_kl_vae"] = kl_vae_loss
 
-        if end2end_info is not None:
-
-            # gen loss
-            loss_gen_end2end = self.generator_loss(scores_fake=end2end_info["scores_disc_fake"])[0] * self.gen_loss_alpha
-            return_dict["loss_gen_end2end"] = loss_gen_end2end
-            loss += loss_gen_end2end
-
-            # if do not uses soft dtw
-            if end2end_info["z_predicted"] is not None:
-                # loss KL using GT durations
-                z = end2end_info["z"].float()
-                logs_q = end2end_info["logs_q"].float()
-                z_predicted = end2end_info["z_predicted"].float()
-                logs_p = end2end_info["logs_p"].float()
-                z_mask = end2end_info["z_mask"].float()
-
-                kl = logs_p - logs_q - 0.5
-                kl += 0.5 * ((z - z_predicted) ** 2) * torch.exp(-2.0 * logs_p)
-                kl = torch.sum(kl * z_mask)
-                loss_kl_end2end_gt_durations = kl / torch.sum(z_mask)
-                return_dict["loss_kl_end2end_gt_durations"] = loss_kl_end2end_gt_durations
-                loss += loss_kl_end2end_gt_durations
-            else:
-                pass
-                # ToDo: implement soft dtw
-
         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
         return_dict["loss_kl"] = loss_kl
@@ -822,7 +795,7 @@ class VitsDiscriminatorLoss(nn.Module):
             fake_losses.append(fake_loss.item())
         return loss, real_losses, fake_losses
 
-    def forward(self, scores_disc_real, scores_disc_fake, scores_disc_zp=None, scores_disc_mp=None, end2end_info=None):
+    def forward(self, scores_disc_real, scores_disc_fake, scores_disc_zp=None, scores_disc_mp=None):
         return_dict = {}
         return_dict["loss"] = 0.0
         loss_disc, loss_disc_real, _ = self.discriminator_loss(
@@ -844,17 +817,6 @@ class VitsDiscriminatorLoss(nn.Module):
             return_dict["loss_disc_latent"] = loss_disc_latent * self.disc_latent_loss_alpha
             return_dict["loss"] += return_dict["loss_disc_latent"]
 
-        if end2end_info is not None:
-            loss_disc_end2end, loss_disc_real_end2end, _ = self.discriminator_loss(
-                scores_real=end2end_info["scores_disc_real"], scores_fake=end2end_info["scores_disc_fake"],
-            )
-            return_dict["loss_disc_end2end"] = loss_disc_end2end * self.disc_loss_alpha
-            return_dict["loss"] += return_dict["loss_disc_end2end"]
-
-            for i, ldr in enumerate(loss_disc_real_end2end):
-                return_dict[f"loss_disc_end2end_real_{i}"] = ldr
-
-
         return return_dict
 
 
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index ca5d607e..526f8e7b 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -562,10 +562,6 @@ class VitsArgs(Coqpit):
     use_prosody_conditional_flow_module: bool = False
     prosody_conditional_flow_module_on_decoder: bool = False
 
-    # end 2 end loss
-    use_end2end_loss: bool = False
-    use_soft_dtw: bool = False
-
     use_latent_discriminator: bool = False
 
     detach_dp_input: bool = True
@@ -1069,7 +1065,6 @@ class Vits(BaseTTS):
         return g
 
     def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
-        predicted_durations = None
         # find the alignment path
         attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
         with torch.no_grad():
@@ -1092,15 +1087,6 @@ class Vits(BaseTTS):
                 lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
             )
             loss_duration = loss_duration / torch.sum(x_mask)
-            if self.args.use_end2end_loss:
-                predicted_durations = self.duration_predictor(
-                    x.detach() if self.args.detach_dp_input else x,
-                    x_mask,
-                    g=g.detach() if self.args.detach_dp_input and g is not None else g,
-                    reverse=True,
-                    noise_scale=self.inference_noise_scale_dp,
-                    lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb
-                )
         else:
             attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
             log_durations = self.duration_predictor(
@@ -1109,10 +1095,9 @@ class Vits(BaseTTS):
                 g=g.detach() if self.args.detach_dp_input and g is not None else g,
                 lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
             )
-            predicted_durations = log_durations
             loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
         outputs["loss_duration"] = loss_duration
-        return outputs, attn, predicted_durations
+        return outputs, attn
 
     def upsampling_z(self, z, slice_ids=None, y_lengths=None, y_mask=None):
         spec_segment_size = self.spec_segment_size
@@ -1268,7 +1253,7 @@ class Vits(BaseTTS):
             else:
                 g_dp = torch.cat([g_dp, pros_emb], dim=1)  # [b, h1+h2, 1]
 
-        outputs, attn, predicted_durations = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)
+        outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)
 
         # expand prior
         m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
@@ -1322,43 +1307,6 @@ class Vits(BaseTTS):
         else:
             gt_cons_emb, syn_cons_emb = None, None
 
-        end2end_dict = None
-        if self.args.use_end2end_loss:
-            # predicted_durations
-            w = torch.exp(predicted_durations) * x_mask * self.length_scale
-            w_ceil = torch.ceil(w)
-            y_lengths_end2end = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
-            y_mask_end2end = sequence_mask(y_lengths_end2end, None).to(x_mask.dtype).unsqueeze(1)  # [B, 1, T_dec]
-
-            attn_mask = x_mask * y_mask_end2end.transpose(1, 2)  # [B, 1, T_enc] * [B, T_dec, 1]
-            attn_end2end = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))
-
-            m_p_end2end = torch.matmul(attn_end2end.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
-            logs_p_end2end = torch.matmul(attn_end2end.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
-
-            z_p_end2end = m_p_end2end * y_mask_end2end #+ torch.randn_like(m_p_end2end) * torch.exp(logs_p_end2end) * self.inference_noise_scale
-
-            # conditional module
-            if self.args.use_prosody_conditional_flow_module:
-                if self.args.prosody_conditional_flow_module_on_decoder:
-                    z_p_end2end = self.prosody_conditional_module(z_p_end2end, y_mask_end2end, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb, reverse=True)
-
-            z_end2end = self.flow(z_p_end2end, y_mask_end2end, g=g, reverse=True)
-
-            # interpolate z if needed
-            z_end2end, _, _, y_mask_end2end = self.upsampling_z(z, y_lengths=y_lengths_end2end, y_mask=y_mask_end2end)
-            # z_slice_end2end, spec_segment_size, slice_ids_end2end, _ = self.upsampling_z(z_slice_end2end, slice_ids=slice_ids_end2end)
-
-            # generate all z using the vocoder
-            o_end2end = self.waveform_decoder(z_end2end, g=g)
-            wav_seg_end2end = waveform
-
-            z_predicted_gt_durations = None
-            if not self.args.use_soft_dtw:
-                z_predicted_gt_durations = self.flow(m_p_expanded * y_mask, y_mask, g=g, reverse=True)
-
-            end2end_dict = {"logs_p_end2end": logs_p_end2end, "logs_p": logs_p_expanded, "logs_q": logs_q, "z_mask": y_mask, "z_mask_end2end": y_mask_end2end, "z": z, "z_predicted_end2end": z_end2end, "z_predicted": z_predicted_gt_durations, "model_outputs": o_end2end, "waveform_seg": wav_seg_end2end}
-
         outputs.update(
             {
                 "model_outputs": o,
@@ -1377,8 +1325,7 @@ class Vits(BaseTTS):
                 "loss_prosody_enc_spk_rev_classifier": l_pros_speaker,
                 "loss_prosody_enc_emo_classifier": l_pros_emotion,
                 "loss_text_enc_spk_rev_classifier": l_text_speaker,
-                "loss_text_enc_emo_classifier": l_text_emotion,
-                "end2end_info": end2end_dict,
+                "loss_text_enc_emo_classifier": l_text_emotion
             }
         )
         return outputs
@@ -1686,21 +1633,13 @@ class Vits(BaseTTS):
                 outputs["model_outputs"].detach(), outputs["waveform_seg"], outputs["m_p"].detach(), outputs["z_p"].detach()
             )
 
-            end2end_info = None
-            if self.args.use_end2end_loss:
-                scores_disc_fake_end2end, _, scores_disc_real_end2end, _, _, _, _, _ = self.disc(
-                    outputs["end2end_info"]["model_outputs"].detach(), self.model_outputs_cache["end2end_info"]["waveform_seg"]
-                )
-                end2end_info = {"scores_disc_real": scores_disc_real_end2end, "scores_disc_fake": scores_disc_fake_end2end}
-
             # compute loss
             with autocast(enabled=False):  # use float32 for the criterion
                 loss_dict = criterion[optimizer_idx](
                     scores_disc_real,
                     scores_disc_fake,
                     scores_disc_zp,
-                    scores_disc_mp,
-                    end2end_info=end2end_info,
+                    scores_disc_mp
                 )
             return outputs, loss_dict
 
@@ -1735,14 +1674,6 @@ class Vits(BaseTTS):
                 self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"], self.model_outputs_cache["m_p"], self.model_outputs_cache["z_p"].detach()
             )
 
-            if self.args.use_end2end_loss:
-                scores_disc_fake_end2end, feats_disc_fake_end2end, _, feats_disc_real_end2end, _, _, _, _, _ = self.disc(
-                    self.model_outputs_cache["end2end_info"]["model_outputs"], self.model_outputs_cache["end2end_info"]["waveform_seg"]
-                )
-                self.model_outputs_cache["end2end_info"]["scores_disc_fake"] = scores_disc_fake_end2end
-                self.model_outputs_cache["end2end_info"]["feats_disc_fake"] = feats_disc_fake_end2end
-                self.model_outputs_cache["end2end_info"]["feats_disc_real"] = feats_disc_real_end2end
-
             # compute losses
             with autocast(enabled=False):  # use float32 for the criterion
                 loss_dict = criterion[optimizer_idx](
@@ -1768,8 +1699,7 @@ class Vits(BaseTTS):
                     loss_text_enc_emo_classifier=self.model_outputs_cache["loss_text_enc_emo_classifier"],
                     scores_disc_mp=scores_disc_mp,
                     feats_disc_mp=feats_disc_mp,
-                    feats_disc_zp=feats_disc_zp,
-                    end2end_info=self.model_outputs_cache["end2end_info"],
+                    feats_disc_zp=feats_disc_zp
                 )
 
             return self.model_outputs_cache, loss_dict