From ac3f98cefb5cf1aea115f83f48f6fa0404c838aa Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 17 May 2022 13:14:15 +0000 Subject: [PATCH] Add text encoder reversal speaker classifier loss --- TTS/tts/configs/vits_config.py | 1 + TTS/tts/layers/losses.py | 16 +++++++++--- TTS/tts/models/vits.py | 25 +++++++++++++++---- ...est_vits_speaker_emb_with_emotion_train.py | 10 ++++---- 4 files changed, 38 insertions(+), 14 deletions(-) diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py index 9c2157a9..ff8fcf12 100644 --- a/TTS/tts/configs/vits_config.py +++ b/TTS/tts/configs/vits_config.py @@ -115,6 +115,7 @@ class VitsConfig(BaseTTSConfig): mel_loss_alpha: float = 45.0 dur_loss_alpha: float = 1.0 consistency_loss_alpha: float = 1.0 + text_enc_spk_reversal_loss_alpha: float = 2.0 # data loader params return_wav: bool = True diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 1d47745c..14c7b9b5 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -590,6 +590,8 @@ class VitsGeneratorLoss(nn.Module): self.dur_loss_alpha = c.dur_loss_alpha self.mel_loss_alpha = c.mel_loss_alpha self.consistency_loss_alpha = c.consistency_loss_alpha + self.text_enc_spk_reversal_loss_alpha = c.text_enc_spk_reversal_loss_alpha + self.stft = TorchSTFT( c.audio.fft_size, c.audio.hop_length, @@ -662,7 +664,8 @@ class VitsGeneratorLoss(nn.Module): use_encoder_consistency_loss=False, gt_cons_emb=None, syn_cons_emb=None, - loss_spk_reversal_classifier=None, + loss_prosody_enc_spk_rev_classifier=None, + loss_text_enc_spk_rev_classifier=None, ): """ Shapes: @@ -698,9 +701,14 @@ class VitsGeneratorLoss(nn.Module): loss = loss + loss_enc return_dict["loss_consistency_enc"] = loss_enc - if loss_spk_reversal_classifier is not None: - loss += loss_spk_reversal_classifier - return_dict["loss_spk_reversal_classifier"] = loss_spk_reversal_classifier + if loss_prosody_enc_spk_rev_classifier is not None: + loss += loss_prosody_enc_spk_rev_classifier + return_dict["loss_prosody_enc_spk_rev_classifier"] = loss_prosody_enc_spk_rev_classifier + + if loss_text_enc_spk_rev_classifier is not None: + loss_text_enc_spk_rev_classifier = loss_text_enc_spk_rev_classifier * self.text_enc_spk_reversal_loss_alpha + loss += loss_text_enc_spk_rev_classifier + return_dict["loss_text_enc_spk_rev_classifier"] = loss_text_enc_spk_rev_classifier # pass losses to the dict return_dict["loss_gen"] = loss_gen diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 21333178..1646289b 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -540,6 +540,7 @@ class VitsArgs(Coqpit): external_emotions_embs_file: str = None emotion_embedding_dim: int = 0 num_emotions: int = 0 + use_text_enc_spk_reversal_classifier: bool = False # prosody encoder use_prosody_encoder: bool = False @@ -689,12 +690,19 @@ class Vits(BaseTTS): num_style_tokens=self.args.prosody_encoder_num_tokens, gst_embedding_dim=self.args.prosody_embedding_dim, ) - self.speaker_reversal_classifier = ReversalClassifier( + self.speaker_pros_enc_reversal_classifier = ReversalClassifier( in_channels=self.args.prosody_embedding_dim, out_channels=self.num_speakers, hidden_channels=256, ) + if self.args.use_text_enc_spk_reversal_classifier: + self.speaker_text_enc_reversal_classifier = ReversalClassifier( + in_channels=self.args.hidden_channels + self.args.emotion_embedding_dim + self.args.prosody_embedding_dim, + out_channels=self.num_speakers, + hidden_channels=256, + ) + self.waveform_decoder = HifiganGenerator( 
self.args.hidden_channels, 1, @@ -1081,10 +1089,15 @@ class Vits(BaseTTS): l_pros_speaker = None if self.args.use_prosody_encoder: pros_emb = self.prosody_encoder(z).transpose(1, 2) - _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None) + _, l_pros_speaker = self.speaker_pros_enc_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None) x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb) + # reversal speaker loss to force the encoder to be speaker identity free + l_text_speaker = None + if self.args.use_text_enc_spk_reversal_classifier: + _, l_text_speaker = self.speaker_text_enc_reversal_classifier(x.transpose(1, 2), sid, x_mask=None) + # flow layers z_p = self.flow(z, y_mask, g=g) @@ -1159,7 +1172,8 @@ class Vits(BaseTTS): "gt_cons_emb": gt_cons_emb, "syn_cons_emb": syn_cons_emb, "slice_ids": slice_ids, - "loss_spk_reversal_classifier": l_pros_speaker, + "loss_prosody_enc_spk_rev_classifier": l_pros_speaker, + "loss_text_enc_spk_rev_classifier": l_text_speaker, } ) return outputs @@ -1465,7 +1479,8 @@ class Vits(BaseTTS): or self.args.use_emotion_encoder_as_loss, gt_cons_emb=self.model_outputs_cache["gt_cons_emb"], syn_cons_emb=self.model_outputs_cache["syn_cons_emb"], - loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"], + loss_prosody_enc_spk_rev_classifier=self.model_outputs_cache["loss_prosody_enc_spk_rev_classifier"], + loss_text_enc_spk_rev_classifier=self.model_outputs_cache["loss_text_enc_spk_rev_classifier"] ) return self.model_outputs_cache, loss_dict @@ -1636,7 +1651,7 @@ class Vits(BaseTTS): if ( self.speaker_manager is not None and self.speaker_manager.ids - and (self.args.use_speaker_embedding or self.args.use_prosody_encoder) + and (self.args.use_speaker_embedding or self.args.use_prosody_encoder or self.args.use_text_enc_spk_reversal_classifier) ): speaker_ids = [self.speaker_manager.ids[sn] for sn in batch["speaker_names"]] diff --git a/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py b/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py index f200c806..cd9118ad 100644 --- a/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py +++ b/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py @@ -34,18 +34,18 @@ config.audio.do_trim_silence = True config.audio.trim_db = 60 # active multispeaker d-vec mode -config.model_args.use_speaker_embedding = True -config.model_args.use_d_vector_file = False +config.model_args.use_speaker_embedding = False +config.model_args.use_d_vector_file = True config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json" config.model_args.speaker_embedding_channels = 128 config.model_args.d_vector_dim = 256 # emotion -config.model_args.use_external_emotions_embeddings = False -config.model_args.use_emotion_embedding = True -config.model_args.emotion_just_encoder = False +config.model_args.use_external_emotions_embeddings = True +config.model_args.use_emotion_embedding = False config.model_args.emotion_embedding_dim = 256 config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json" +config.model_args.use_text_enc_spk_reversal_classifier = False # consistency loss # config.model_args.use_emotion_encoder_as_loss = True
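
The loss introduced here relies on a gradient-reversal (domain-adversarial) classifier: the classifier is trained to predict the speaker from the text-encoder output, while the reversed gradient pushes the encoder to drop speaker identity from that representation. The sketch below illustrates the mechanism only; GradientReversal and SpeakerReversalClassifier are hypothetical stand-ins written for this note and are not the project's ReversalClassifier used in the diff.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class GradientReversal(torch.autograd.Function):
        # Identity in the forward pass; negated, scaled gradient in the backward pass.
        @staticmethod
        def forward(ctx, x, scale):
            ctx.scale = scale
            return x.clone()

        @staticmethod
        def backward(ctx, grad_output):
            return -ctx.scale * grad_output, None


    class SpeakerReversalClassifier(nn.Module):
        def __init__(self, in_channels, out_channels, hidden_channels=256, grad_scale=1.0):
            super().__init__()
            self.grad_scale = grad_scale
            self.classifier = nn.Sequential(
                nn.Linear(in_channels, hidden_channels),
                nn.ReLU(),
                nn.Linear(hidden_channels, out_channels),
            )

        def forward(self, x, speaker_ids):
            # x: [B, T, C] encoder outputs, speaker_ids: [B] integer speaker labels.
            x = GradientReversal.apply(x, self.grad_scale)
            logits = self.classifier(x)  # [B, T, num_speakers]
            # Per-frame cross entropy against the utterance-level speaker id.
            targets = speaker_ids.unsqueeze(1).expand(-1, logits.size(1))
            loss = F.cross_entropy(logits.transpose(1, 2), targets)
            return logits, loss

As in the diff, the text-encoder classifier receives the encoder output transposed to [B, T, C], its loss is scaled by text_enc_spk_reversal_loss_alpha (2.0 by default in VitsConfig) inside VitsGeneratorLoss.forward before being added to the total generator loss, and the feature is enabled with model_args.use_text_enc_spk_reversal_classifier = True.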