From 5090034fd16e8846a80729113083d13960baca76 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 15 Mar 2022 12:35:00 +0000 Subject: [PATCH] Add emotion consistency loss --- TTS/tts/configs/vits_config.py | 2 +- TTS/tts/layers/losses.py | 20 ++-- TTS/tts/models/vits.py | 93 ++++++++++++------- TTS/utils/manage.py | 12 +-- ...est_vits_speaker_emb_with_emotion_train.py | 13 ++- 5 files changed, 88 insertions(+), 52 deletions(-) diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py index a8c7f91d..9c2157a9 100644 --- a/TTS/tts/configs/vits_config.py +++ b/TTS/tts/configs/vits_config.py @@ -114,7 +114,7 @@ class VitsConfig(BaseTTSConfig): feat_loss_alpha: float = 1.0 mel_loss_alpha: float = 45.0 dur_loss_alpha: float = 1.0 - speaker_encoder_loss_alpha: float = 1.0 + consistency_loss_alpha: float = 1.0 # data loader params return_wav: bool = True diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index e03cf084..6871c0ef 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -532,7 +532,7 @@ class VitsGeneratorLoss(nn.Module): self.feat_loss_alpha = c.feat_loss_alpha self.dur_loss_alpha = c.dur_loss_alpha self.mel_loss_alpha = c.mel_loss_alpha - self.spk_encoder_loss_alpha = c.speaker_encoder_loss_alpha + self.consistency_loss_alpha = c.consistency_loss_alpha self.stft = TorchSTFT( c.audio.fft_size, c.audio.hop_length, @@ -586,8 +586,8 @@ class VitsGeneratorLoss(nn.Module): return l @staticmethod - def cosine_similarity_loss(gt_spk_emb, syn_spk_emb): - return -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() + def cosine_similarity_loss(gt_cons_emb, syn_cons_emb): + return -torch.nn.functional.cosine_similarity(gt_cons_emb, syn_cons_emb).mean() def forward( self, @@ -602,9 +602,9 @@ class VitsGeneratorLoss(nn.Module): feats_disc_fake, feats_disc_real, loss_duration, - use_speaker_encoder_as_loss=False, - gt_spk_emb=None, - syn_spk_emb=None, + use_encoder_consistency_loss=False, + gt_cons_emb=None, + syn_cons_emb=None, ): """ Shapes: @@ -635,10 +635,10 @@ class VitsGeneratorLoss(nn.Module): loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration - if use_speaker_encoder_as_loss: - loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha - loss = loss + loss_se - return_dict["loss_spk_encoder"] = loss_se + if use_encoder_consistency_loss: + loss_enc = self.cosine_similarity_loss(gt_cons_emb, syn_cons_emb) * self.consistency_loss_alpha + loss = loss + loss_enc + return_dict["loss_consistency_enc"] = loss_enc # pass losses to the dict return_dict["loss_gen"] = loss_gen return_dict["loss_kl"] = loss_kl diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index f2a2f2b1..5848f295 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -429,10 +429,10 @@ class VitsArgs(Coqpit): use_speaker_encoder_as_loss (bool): Enable/Disable Speaker Consistency Loss (SCL). Defaults to False. - speaker_encoder_config_path (str): + encoder_config_path (str): Path to the file speaker encoder config file, to use for SCL. Defaults to "". - speaker_encoder_model_path (str): + encoder_model_path (str): Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "". 
condition_dp_on_speaker (bool): @@ -507,8 +507,9 @@ class VitsArgs(Coqpit): num_languages: int = 0 language_ids_file: str = None use_speaker_encoder_as_loss: bool = False - speaker_encoder_config_path: str = "" - speaker_encoder_model_path: str = "" + use_emotion_encoder_as_loss: bool = False + encoder_config_path: str = "" + encoder_model_path: str = "" condition_dp_on_speaker: bool = True freeze_encoder: bool = False freeze_DP: bool = False @@ -559,6 +560,7 @@ class Vits(BaseTTS): self.init_multispeaker(config) self.init_multilingual(config) self.init_emotion(config, emotion_manager) + self.init_consistency_loss() self.length_scale = self.args.length_scale self.noise_scale = self.args.noise_scale @@ -661,15 +663,21 @@ class Vits(BaseTTS): if self.args.use_d_vector_file: self._init_d_vector() - # TODO: make this a function - if self.args.use_speaker_encoder_as_loss: - if self.speaker_manager.encoder is None and ( - not self.args.speaker_encoder_model_path or not self.args.speaker_encoder_config_path - ): - raise RuntimeError( - " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!" + def init_consistency_loss(self): + if self.args.use_speaker_encoder_as_loss and self.args.use_emotion_encoder_as_loss: + raise RuntimeError( + " [!] The use of speaker consistency loss (SCL) and emotion consistency loss (ECL) together is not supported, please disable one of those !!" ) + if self.args.use_speaker_encoder_as_loss: + if self.speaker_manager.encoder is None and ( + not self.args.encoder_model_path or not self.args.encoder_config_path + ): + raise RuntimeError( + " [!] To use the speaker consistency loss (SCL) you need to specify encoder_model_path and encoder_config_path !!" + ) + # load encoder + self.speaker_manager.init_encoder(self.args.encoder_model_path, self.args.encoder_config_path) self.speaker_manager.encoder.eval() print(" > External Speaker Encoder Loaded !!") @@ -677,15 +685,33 @@ class Vits(BaseTTS): hasattr(self.speaker_manager.encoder, "audio_config") and self.config.audio["sample_rate"] != self.speaker_manager.encoder.audio_config["sample_rate"] ): + # pylint: disable=W0101,W0105 self.audio_transform = torchaudio.transforms.Resample( - orig_freq=self.audio_config["sample_rate"], + orig_freq=self.config.audio["sample_rate"], new_freq=self.speaker_manager.encoder.audio_config["sample_rate"], ) - # pylint: disable=W0101,W0105 - self.audio_transform = torchaudio.transforms.Resample( - orig_freq=self.config.audio.sample_rate, - new_freq=self.speaker_manager.encoder.audio_config["sample_rate"], - ) + + elif self.args.use_emotion_encoder_as_loss: + if self.emotion_manager.encoder is None and ( + not self.args.encoder_model_path or not self.args.encoder_config_path + ): + raise RuntimeError( + " [!] To use the emotion consistency loss (ECL) you need to specify encoder_model_path and encoder_config_path !!" 
+ ) + # load encoder + self.emotion_manager.init_encoder(self.args.encoder_model_path, self.args.encoder_config_path) + self.emotion_manager.encoder.eval() + print(" > External Emotion Encoder Loaded !!") + + if ( + hasattr(self.emotion_manager.encoder, "audio_config") + and self.config.audio["sample_rate"] != self.emotion_manager.encoder.audio_config["sample_rate"] + ): + # pylint: disable=W0101,W0105 + self.audio_transform = torchaudio.transforms.Resample( + orig_freq=self.config.audio["sample_rate"], + new_freq=self.emotion_manager.encoder.audio_config["sample_rate"], + ) def _init_speaker_embedding(self): # pylint: disable=attribute-defined-outside-init @@ -896,8 +922,8 @@ class Vits(BaseTTS): - m_q: :math:`[B, C, T_dec]` - logs_q: :math:`[B, C, T_dec]` - waveform_seg: :math:`[B, 1, spec_seg_size * hop_length]` - - gt_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]` - - syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]` + - gt_cons_emb: :math:`[B, 1, speaker_encoder.proj_dim]` + - syn_cons_emb: :math:`[B, 1, speaker_encoder.proj_dim]` """ outputs = {} sid, g, lid, eid, eg = self._set_cond_input(aux_input) @@ -944,7 +970,8 @@ class Vits(BaseTTS): pad_short=True, ) - if self.args.use_speaker_encoder_as_loss and self.speaker_manager.encoder is not None: + if self.args.use_speaker_encoder_as_loss or self.args.use_emotion_encoder_as_loss: + encoder = self.speaker_manager.encoder if self.args.use_speaker_encoder_as_loss else self.emotion_manager.encoder # concate generated and GT waveforms wavs_batch = torch.cat((wav_seg, o), dim=0) @@ -953,12 +980,15 @@ class Vits(BaseTTS): if self.audio_transform is not None: wavs_batch = self.audio_transform(wavs_batch) - pred_embs = self.speaker_manager.encoder.forward(wavs_batch, l2_norm=True) + if next(encoder.parameters()).device != wavs_batch.device: + encoder = encoder.to(wavs_batch.device) + + pred_embs = encoder.forward(wavs_batch, l2_norm=True) # split generated and GT speaker embeddings - gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0) + gt_cons_emb, syn_cons_emb = torch.chunk(pred_embs, 2, dim=0) else: - gt_spk_emb, syn_spk_emb = None, None + gt_cons_emb, syn_cons_emb = None, None outputs.update( { @@ -971,8 +1001,8 @@ class Vits(BaseTTS): "m_q": m_q, "logs_q": logs_q, "waveform_seg": wav_seg, - "gt_spk_emb": gt_spk_emb, - "syn_spk_emb": syn_spk_emb, + "gt_cons_emb": gt_cons_emb, + "syn_cons_emb": syn_cons_emb, "slice_ids": slice_ids, } ) @@ -1058,6 +1088,7 @@ class Vits(BaseTTS): outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p} return outputs + @torch.no_grad() def inference_voice_conversion(self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None): """Inference for voice conversion @@ -1201,9 +1232,9 @@ class Vits(BaseTTS): feats_disc_fake=feats_disc_fake, feats_disc_real=feats_disc_real, loss_duration=self.model_outputs_cache["loss_duration"], - use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss, - gt_spk_emb=self.model_outputs_cache["gt_spk_emb"], - syn_spk_emb=self.model_outputs_cache["syn_spk_emb"], + use_encoder_consistency_loss=self.args.use_speaker_encoder_as_loss or self.args.use_emotion_encoder_as_loss, + gt_cons_emb=self.model_outputs_cache["gt_cons_emb"], + syn_cons_emb=self.model_outputs_cache["syn_cons_emb"], ) return self.model_outputs_cache, loss_dict @@ -1317,7 +1348,7 @@ class Vits(BaseTTS): "d_vector": d_vector, "language_id": language_id, "language_name": language_name, - 
"emotion_embedding": emotion_embedding, + "emotion_embedding": emotion_embedding, "emotion_ids": emotion_id } @@ -1584,9 +1615,9 @@ class Vits(BaseTTS): language_manager = LanguageManager.init_from_config(config) emotion_manager = EmotionManager.init_from_config(config) - if config.model_args.speaker_encoder_model_path: + if config.model_args.encoder_model_path: speaker_manager.init_encoder( - config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path + config.model_args.encoder_model_path, config.model_args.encoder_config_path ) return Vits(new_config, ap, tokenizer, speaker_manager, language_manager, emotion_manager=emotion_manager) diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 01d54ad6..cb34c4e4 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -151,8 +151,8 @@ class ModelManager(object): output_stats_path = os.path.join(output_path, "scale_stats.npy") output_d_vector_file_path = os.path.join(output_path, "speakers.json") output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json") - speaker_encoder_config_path = os.path.join(output_path, "config_se.json") - speaker_encoder_model_path = os.path.join(output_path, "model_se.pth.tar") + encoder_config_path = os.path.join(output_path, "config_se.json") + encoder_model_path = os.path.join(output_path, "model_se.pth.tar") # update the scale_path.npy file path in the model config.json self._update_path("audio.stats_path", output_stats_path, config_path) @@ -166,10 +166,10 @@ class ModelManager(object): self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path) # update the speaker_encoder file path in the model config.json to the current path - self._update_path("speaker_encoder_model_path", speaker_encoder_model_path, config_path) - self._update_path("model_args.speaker_encoder_model_path", speaker_encoder_model_path, config_path) - self._update_path("speaker_encoder_config_path", speaker_encoder_config_path, config_path) - self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path) + self._update_path("encoder_model_path", encoder_model_path, config_path) + self._update_path("model_args.encoder_model_path", encoder_model_path, config_path) + self._update_path("encoder_config_path", encoder_config_path, config_path) + self._update_path("model_args.encoder_config_path", encoder_config_path, config_path) @staticmethod def _update_path(field_name, new_path, config_path): diff --git a/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py b/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py index 6ce59c6c..69b3ccd5 100644 --- a/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py +++ b/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py @@ -36,15 +36,20 @@ config.audio.trim_db = 60 # active multispeaker d-vec mode config.model_args.use_speaker_embedding = True config.model_args.use_d_vector_file = False -config.model_args.d_vector_file = None +config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json" +config.model_args.speaker_embedding_channels = 128 config.model_args.d_vector_dim = 256 # emotion -config.model_args.use_external_emotions_embeddings = True -config.model_args.use_emotion_embedding = False +config.model_args.use_external_emotions_embeddings = False +config.model_args.use_emotion_embedding = True config.model_args.emotion_embedding_dim = 256 config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json" +# consistency loss +# 
config.model_args.use_emotion_encoder_as_loss = True +# config.model_args.encoder_model_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/model_se.pth.tar" +# config.model_args.encoder_config_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/config_se.json" config.save_json(config_path) @@ -69,7 +74,7 @@ continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) out_wav_path = os.path.join(get_tests_output_path(), "output.wav") speaker_id = "ljspeech-1" -emotion_id = "ljspeech-1" +emotion_id = "ljspeech-3" continue_speakers_path = os.path.join(continue_path, "speakers.json") continue_emotion_path = os.path.join(continue_path, "speakers.json")
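
Notes (illustration only): a minimal self-contained sketch of the consistency-loss mechanics this patch introduces, using the names from the patched code. The encoder below is a placeholder module standing in for speaker_manager.encoder / emotion_manager.encoder (which is also called with l2_norm=True in the real forward pass), and the tensor shapes are illustrative assumptions.

    import torch
    import torch.nn.functional as F

    def cosine_similarity_loss(gt_cons_emb, syn_cons_emb):
        # Negative mean cosine similarity: minimizing it pulls the embeddings of the
        # generated segments towards the embeddings of the ground-truth segments.
        return -F.cosine_similarity(gt_cons_emb, syn_cons_emb).mean()

    # Placeholder encoder standing in for the external speaker/emotion encoder.
    encoder = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.LazyLinear(256))

    wav_seg = torch.randn(8, 1, 8192)  # ground-truth waveform segments [B, 1, spec_seg_size * hop_length]
    o = torch.randn(8, 1, 8192)        # generated waveform segments, same shape

    # Concatenate GT and generated segments, embed them in one pass, then split,
    # mirroring the flow in Vits.forward() above.
    wavs_batch = torch.cat((wav_seg, o), dim=0)
    pred_embs = encoder(wavs_batch)
    gt_cons_emb, syn_cons_emb = torch.chunk(pred_embs, 2, dim=0)

    loss_enc = cosine_similarity_loss(gt_cons_emb, syn_cons_emb) * 1.0  # scaled by consistency_loss_alpha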
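
To actually exercise the emotion consistency loss (ECL), the commented-out block in the test would be enabled with paths pointing at an emotion-encoder checkpoint and config; the paths below are placeholders, not files shipped with the repository.

    # Hypothetical local paths; substitute a real emotion-encoder checkpoint and config.
    config.model_args.use_emotion_encoder_as_loss = True
    config.model_args.encoder_model_path = "/path/to/emotion_encoder/model_se.pth.tar"
    config.model_args.encoder_config_path = "/path/to/emotion_encoder/config_se.json"
    config.consistency_loss_alpha = 1.0  # weight of the consistency term (VitsConfig default)

    # use_speaker_encoder_as_loss must remain False: init_consistency_loss() raises a
    # RuntimeError if SCL and ECL are enabled together.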