From 9071bf326f1bf670e97457d2ef469d90e06021d4 Mon Sep 17 00:00:00 2001 From: Edresson Date: Wed, 25 Aug 2021 16:52:02 -0300 Subject: [PATCH] Implement vocoder Fine Tuning like SC-GlowTTS paper --- TTS/tts/layers/losses.py | 9 ++- TTS/tts/models/vits.py | 140 ++++++++++++++++++++++++++++++++++----- 2 files changed, 133 insertions(+), 16 deletions(-) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 0ea342e8..145cd1a0 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -598,6 +598,7 @@ class VitsGeneratorLoss(nn.Module): feats_disc_fake, feats_disc_real, loss_duration, + fine_tuning_mode=False, ): """ Shapes: @@ -619,9 +620,15 @@ class VitsGeneratorLoss(nn.Module): mel = self.stft(waveform) mel_hat = self.stft(waveform_hat) # compute losses + + # ignore tts model loss if fine tunning mode is on + if fine_tuning_mode: + loss_kl = 0.0 + else: + loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha + loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha - loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index e7305fb8..ce75d6dd 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -193,6 +193,7 @@ class VitsArgs(Coqpit): use_language_embedding: bool = False embedded_language_dim: int = 4 num_languages: int = 0 + fine_tuning_mode: bool = False class Vits(BaseTTS): @@ -330,6 +331,7 @@ class Vits(BaseTTS): if args.init_discriminator: self.disc = VitsDiscriminator(use_spectral_norm=args.use_spectral_norm_disriminator) + print("FINE TUNING:", self.args.fine_tuning_mode) def init_multispeaker(self, config: Coqpit): """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer @@ -521,6 +523,90 @@ class Vits(BaseTTS): ) return outputs + def forward_fine_tuning( + self, + x: torch.tensor, + x_lengths: torch.tensor, + y: torch.tensor, + y_lengths: torch.tensor, + aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}, + ) -> Dict: + """Forward pass of the model. + + Args: + x (torch.tensor): Batch of input character sequence IDs. + x_lengths (torch.tensor): Batch of input character sequence lengths. + y (torch.tensor): Batch of input spectrograms. + y_lengths (torch.tensor): Batch of input spectrogram lengths. + aux_input (dict, optional): Auxiliary inputs for multi-speaker training. Defaults to {"d_vectors": None, "speaker_ids": None}. + + Returns: + Dict: model outputs keyed by the output name. + + Shapes: + - x: :math:`[B, T_seq]` + - x_lengths: :math:`[B]` + - y: :math:`[B, C, T_spec]` + - y_lengths: :math:`[B]` + - d_vectors: :math:`[B, C, 1]` + - speaker_ids: :math:`[B]` + """ + with torch.no_grad(): + outputs = {} + sid, g, lid = self._set_cond_input(aux_input) + # speaker embedding + if self.args.use_speaker_embedding and sid is not None and not self.use_d_vector: + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + + # language embedding + lang_emb=None + if self.args.use_language_embedding and lid is not None: + lang_emb = self.emb_l(lid).unsqueeze(-1) + + x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb) + + # posterior encoder + z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g) + + # flow layers + z_p = self.flow(z, y_mask, g=g) + + # find the alignment path + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + with torch.no_grad(): + o_scale = torch.exp(-2 * logs_p) + # logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1] + logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)]) + logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p]) + # logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp = logp2 + logp3 + attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() + + # expand prior + m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p]) + logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p]) + + # get the z after inverse decoder + # ToDo: test if using m_p the result is better (In the SC-GlowTTS paper we used mp instead z_p) + z_f_pred = self.flow(z_p, y_mask, g=g, reverse=True) + z_slice, slice_ids = rand_segment(z_f_pred, y_lengths, self.spec_segment_size) + + o = self.waveform_decoder(z_slice, g=g) + outputs.update( + { + "model_outputs": o, + "alignments": attn.squeeze(1), + "slice_ids": slice_ids, + "z": z, + "z_p": z_p, + "m_p": m_p, + "logs_p": logs_p, + "m_q": m_q, + "logs_q": logs_q, + } + ) + return outputs + def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}): """ Shapes: @@ -599,6 +685,15 @@ class Vits(BaseTTS): if optimizer_idx not in [0, 1]: raise ValueError(" [!] Unexpected `optimizer_idx`.") + # generator pass + if self.args.fine_tuning_mode: + # ToDo: find better place fot it + # force eval mode + self.eval() + # restore train mode for the vocoder part + self.waveform_decoder.train() + self.disc.train() + if optimizer_idx == 0: text_input = batch["text_input"] text_lengths = batch["text_lengths"] @@ -610,13 +705,24 @@ class Vits(BaseTTS): waveform = batch["waveform"] # generator pass - outputs = self.forward( - text_input, - text_lengths, - linear_input.transpose(1, 2), - mel_lengths, - aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids}, - ) + if self.args.fine_tuning_mode: + + # model forward + outputs = self.forward_fine_tuning( + text_input, + text_lengths, + linear_input.transpose(1, 2), + mel_lengths, + aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids}, + ) + else: + outputs = self.forward( + text_input, + text_lengths, + linear_input.transpose(1, 2), + mel_lengths, + aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids}, + ) # cache tensors for the discriminator self.y_disc_cache = None @@ -649,15 +755,17 @@ class Vits(BaseTTS): feats_disc_fake=outputs["feats_disc_fake"], feats_disc_real=outputs["feats_disc_real"], loss_duration=outputs["loss_duration"], + fine_tuning_mode=self.args.fine_tuning_mode, ) - - # handle the duration loss - if self.args.use_sdp: - loss_dict["nll_duration"] = outputs["nll_duration"] - loss_dict["loss"] += outputs["nll_duration"] - else: - loss_dict["loss_duration"] = outputs["loss_duration"] - loss_dict["loss"] += outputs["loss_duration"] + # ignore duration loss if fine tuning mode is on + if not self.args.fine_tuning_mode: + # handle the duration loss + if self.args.use_sdp: + loss_dict["nll_duration"] = outputs["nll_duration"] + loss_dict["loss"] += outputs["nll_duration"] + else: + loss_dict["loss_duration"] = outputs["loss_duration"] + loss_dict["loss"] += outputs["loss_duration"] elif optimizer_idx == 1: # discriminator pass @@ -853,3 +961,5 @@ class Vits(BaseTTS): if eval: self.eval() assert not self.training + +