From 690b37d0abbe8e225c48494618eb1e96625ac17a Mon Sep 17 00:00:00 2001
From: Edresson
Date: Fri, 3 Sep 2021 07:37:43 -0300
Subject: [PATCH] Add support to use the speaker encoder as loss function in
 VITS model

---
 TTS/tts/configs/vits_config.py |  1 +
 TTS/tts/layers/losses.py       | 10 ++++++
 TTS/tts/models/vits.py         | 57 +++++++++++++++++++++++++++-------
 3 files changed, 56 insertions(+), 12 deletions(-)

diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py
index cc3e4940..ece414a6 100644
--- a/TTS/tts/configs/vits_config.py
+++ b/TTS/tts/configs/vits_config.py
@@ -117,6 +117,7 @@ class VitsConfig(BaseTTSConfig):
     feat_loss_alpha: float = 1.0
     mel_loss_alpha: float = 45.0
     dur_loss_alpha: float = 1.0
+    speaker_encoder_loss_alpha: float = 1.0
 
     # data loader params
     return_wav: bool = True
diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index 145cd1a0..fdee9c10 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -532,6 +532,7 @@ class VitsGeneratorLoss(nn.Module):
         self.feat_loss_alpha = c.feat_loss_alpha
         self.dur_loss_alpha = c.dur_loss_alpha
         self.mel_loss_alpha = c.mel_loss_alpha
+        self.spk_encoder_loss_alpha = c.speaker_encoder_loss_alpha
         self.stft = TorchSTFT(
             c.audio.fft_size,
             c.audio.hop_length,
@@ -599,6 +600,9 @@ class VitsGeneratorLoss(nn.Module):
         feats_disc_real,
         loss_duration,
         fine_tuning_mode=False,
+        use_speaker_encoder_as_loss=False,
+        gt_spk_emb=None,
+        syn_spk_emb=None,
     ):
         """
         Shapes:
@@ -632,6 +636,12 @@ class VitsGeneratorLoss(nn.Module):
         loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
         loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
         loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
+
+        if use_speaker_encoder_as_loss:
+            loss_se = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * self.spk_encoder_loss_alpha
+            loss += loss_se
+            return_dict["loss_spk_encoder"] = loss_se
+
         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
         return_dict["loss_kl"] = loss_kl
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index d355d5c1..71cc4634 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -195,6 +195,10 @@ class VitsArgs(Coqpit):
     embedded_language_dim: int = 4
     num_languages: int = 0
     fine_tuning_mode: bool = False
+    use_speaker_encoder_as_loss: bool = False
+    speaker_encoder_config_path: str = ""
+    speaker_encoder_model_path: str = ""
+
 
 
 class Vits(BaseTTS):
@@ -370,6 +374,18 @@ class Vits(BaseTTS):
             self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
             self.embedded_speaker_dim = config.d_vector_dim
 
+        if config.use_speaker_encoder_as_loss:
+            if not config.speaker_encoder_model_path or not config.speaker_encoder_config_path:
+                raise RuntimeError(" [!] To use the speaker encoder loss you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!")
+            self.speaker_manager.init_speaker_encoder(config.speaker_encoder_model_path, config.speaker_encoder_config_path)
+            self.speaker_encoder = self.speaker_manager.speaker_encoder.train()
+            for param in self.speaker_encoder.parameters():
+                param.requires_grad = False
+
+            print(" > External Speaker Encoder Loaded !!")
+        else:
+            self.speaker_encoder = None
+
     def init_multilingual(self, config: Coqpit, data: List = None):
         """Initialize multilingual modules of a model.
 
@@ -427,6 +443,7 @@ class Vits(BaseTTS):
         y: torch.tensor,
         y_lengths: torch.tensor,
         aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
+        waveform=None,
     ) -> Dict:
         """Forward pass of the model.
 
@@ -461,7 +478,6 @@ class Vits(BaseTTS):
 
         x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
-
 
         # posterior encoder
         z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
 
@@ -508,17 +524,36 @@ class Vits(BaseTTS):
         # select a random feature segment for the waveform decoder
         z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size)
         o = self.waveform_decoder(z_slice, g=g)
+
+        wav_seg = segment(
+            waveform.transpose(1, 2),
+            slice_ids * self.config.audio.hop_length,
+            self.args.spec_segment_size * self.config.audio.hop_length,
+        )
+
+        if self.args.use_speaker_encoder_as_loss:
+            # concatenate generated and GT waveforms
+            wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
+            pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True)
+
+            # split generated and GT speaker embeddings
+            gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
+        else:
+            gt_spk_emb, syn_spk_emb = None, None
+
         outputs.update(
             {
                 "model_outputs": o,
                 "alignments": attn.squeeze(1),
-                "slice_ids": slice_ids,
                 "z": z,
                 "z_p": z_p,
                 "m_p": m_p,
                 "logs_p": logs_p,
                 "m_q": m_q,
                 "logs_q": logs_q,
+                "waveform_seg": wav_seg,
+                "gt_spk_emb": gt_spk_emb,
+                "syn_spk_emb": syn_spk_emb,
             }
         )
         return outputs
@@ -596,7 +631,6 @@ class Vits(BaseTTS):
             {
                 "model_outputs": o,
                 "alignments": attn.squeeze(1),
-                "slice_ids": slice_ids,
                 "z": z,
                 "z_p": z_p,
                 "m_p": m_p,
@@ -713,6 +747,7 @@ class Vits(BaseTTS):
                     linear_input.transpose(1, 2),
                     mel_lengths,
                     aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
+                    waveform=waveform,
                 )
             else:
                 outputs = self.forward(
@@ -721,30 +756,25 @@ class Vits(BaseTTS):
                     linear_input.transpose(1, 2),
                     mel_lengths,
                     aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
+                    waveform=waveform,
                 )
 
             # cache tensors for the discriminator
             self.y_disc_cache = None
             self.wav_seg_disc_cache = None
             self.y_disc_cache = outputs["model_outputs"]
-            wav_seg = segment(
-                waveform.transpose(1, 2),
-                outputs["slice_ids"] * self.config.audio.hop_length,
-                self.args.spec_segment_size * self.config.audio.hop_length,
-            )
-            self.wav_seg_disc_cache = wav_seg
-            outputs["waveform_seg"] = wav_seg
+            self.wav_seg_disc_cache = outputs["waveform_seg"]
 
             # compute discriminator scores and features
             outputs["scores_disc_fake"], outputs["feats_disc_fake"], _, outputs["feats_disc_real"] = self.disc(
-                outputs["model_outputs"], wav_seg
+                outputs["model_outputs"], outputs["waveform_seg"]
             )
 
             # compute losses
             with autocast(enabled=False):  # use float32 for the criterion
                 loss_dict = criterion[optimizer_idx](
                     waveform_hat=outputs["model_outputs"].float(),
-                    waveform=wav_seg.float(),
+                    waveform=outputs["waveform_seg"].float(),
                     z_p=outputs["z_p"].float(),
                     logs_q=outputs["logs_q"].float(),
                     m_p=outputs["m_p"].float(),
@@ -755,6 +785,9 @@ class Vits(BaseTTS):
                     feats_disc_real=outputs["feats_disc_real"],
                     loss_duration=outputs["loss_duration"],
                     fine_tuning_mode=self.args.fine_tuning_mode,
+                    use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
+                    gt_spk_emb=outputs["gt_spk_emb"],
+                    syn_spk_emb=outputs["syn_spk_emb"],
                 )
             # ignore duration loss if fine tuning mode is on
             if not self.args.fine_tuning_mode:
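
Note (not part of the patch): the added term is a speaker consistency loss. The frozen external speaker encoder embeds both the ground-truth waveform segment and the generated one, and the negative mean cosine similarity between the two embeddings is added to the generator loss, scaled by speaker_encoder_loss_alpha. Below is a minimal standalone sketch of that computation in plain PyTorch; the speaker_encoder callable is a hypothetical stand-in for the model loaded via SpeakerManager.init_speaker_encoder, and F.normalize is used in place of the encoder's l2_norm=True argument.

import torch
import torch.nn.functional as F

def speaker_consistency_loss(speaker_encoder, gt_wav_seg, gen_wav_seg, alpha=1.0):
    # gt_wav_seg / gen_wav_seg: ground-truth and generated waveform segments, shape (batch, 1, samples)
    # speaker_encoder: frozen embedding model mapping (batch, samples) -> (batch, emb_dim); stand-in here
    # alpha: corresponds to speaker_encoder_loss_alpha in VitsConfig

    # embed GT and generated segments in one batch, mirroring the
    # torch.cat(...) / torch.chunk(...) pattern in Vits.forward
    wavs = torch.cat((gt_wav_seg, gen_wav_seg), dim=0).squeeze(1)
    embs = F.normalize(speaker_encoder(wavs), dim=-1)
    gt_emb, syn_emb = torch.chunk(embs, 2, dim=0)

    # maximizing cosine similarity == minimizing its negative mean
    return -F.cosine_similarity(gt_emb, syn_emb).mean() * alpha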