diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py
index ff8fcf12..b6892d85 100644
--- a/TTS/tts/configs/vits_config.py
+++ b/TTS/tts/configs/vits_config.py
@@ -115,7 +115,8 @@ class VitsConfig(BaseTTSConfig):
     mel_loss_alpha: float = 45.0
     dur_loss_alpha: float = 1.0
     consistency_loss_alpha: float = 1.0
-    text_enc_spk_reversal_loss_alpha: float = 2.0
+    speaker_classifier_loss_alpha: float = 2.0
+    emotion_classifier_loss_alpha: float = 4.0
 
     # data loader params
     return_wav: bool = True
diff --git a/TTS/tts/layers/generic/classifier.py b/TTS/tts/layers/generic/classifier.py
index 938cbdb8..e4f41481 100644
--- a/TTS/tts/layers/generic/classifier.py
+++ b/TTS/tts/layers/generic/classifier.py
@@ -20,7 +20,7 @@ class GradientReversalFunction(torch.autograd.Function):
 
 
 class ReversalClassifier(nn.Module):
-    """Adversarial classifier with a gradient reversal layer.
+    """Adversarial classifier with an optional gradient reversal layer.
     Adapted from: https://github.com/Tomiinek/Multilingual_Text_to_Speech/
 
     Args:
@@ -29,20 +29,22 @@ class ReversalClassifier(nn.Module):
         hidden_channels (int): Number of hidden channels.
         gradient_clipping_bound (float): Maximal value of the gradient which flows from this module. Default: 0.25
         scale_factor (float): Scale multiplier of the reversed gradientts. Default: 1.0
+        reversal (bool): If True, reverses the gradients. Default: True
     """
 
-    def __init__(self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0):
+    def __init__(self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0, reversal=True):
         super().__init__()
+        self.reversal = reversal
         self._lambda = scale_factor
         self._clipping = gradient_clipping_bounds
         self._out_channels = out_channels
         self._classifier = nn.Sequential(
             nn.Linear(in_channels, hidden_channels), nn.ReLU(), nn.Linear(hidden_channels, out_channels)
         )
-        self.test = nn.Linear(in_channels, hidden_channels)
 
     def forward(self, x, labels, x_mask=None):
-        x = GradientReversalFunction.apply(x, self._lambda, self._clipping)
+        if self.reversal:
+            x = GradientReversalFunction.apply(x, self._lambda, self._clipping)
         x = self._classifier(x)
         loss = self.loss(labels, x, x_mask)
         return x, loss
@@ -60,3 +62,4 @@ class ReversalClassifier(nn.Module):
         target[~input_mask] = ignore_index
 
         return nn.functional.cross_entropy(predictions.transpose(1, 2), target, ignore_index=ignore_index)
+
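For reference, the standalone sketch below (illustrative only, not the repository code) shows what the new `reversal` switch changes. With `reversal=True` the classifier gradient is negated, scaled and clipped before it reaches the upstream encoder, so training pushes the encoder to remove the classified attribute (speaker identity); with `reversal=False` the same head acts as a plain auxiliary classifier that rewards keeping the attribute (emotion). Shapes are simplified to [batch, channels] instead of the masked [batch, time, channels] sequences handled by `ReversalClassifier`.

# Illustrative sketch only -- simplified from TTS/tts/layers/generic/classifier.py.
import torch
from torch import nn


class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; negates, scales and clips the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x, scale, clip):
        ctx.scale, ctx.clip = scale, clip
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        grad = -ctx.scale * grad_output
        # one gradient per forward input: x, scale, clip
        return grad.clamp(-ctx.clip, ctx.clip), None, None


class AuxClassifier(nn.Module):
    """Adversarial when reversal=True, a plain (cooperative) classifier when reversal=False."""

    def __init__(self, in_channels, num_classes, hidden=256, reversal=True, scale=1.0, clip=0.25):
        super().__init__()
        self.reversal, self.scale, self.clip = reversal, scale, clip
        self.net = nn.Sequential(nn.Linear(in_channels, hidden), nn.ReLU(), nn.Linear(hidden, num_classes))

    def forward(self, x, labels):
        if self.reversal:
            x = GradReverse.apply(x, self.scale, self.clip)
        logits = self.net(x)
        return logits, nn.functional.cross_entropy(logits, labels)


if __name__ == "__main__":
    enc_out = torch.randn(4, 64, requires_grad=True)  # stands in for an encoder output
    labels = torch.tensor([0, 1, 2, 3])

    # reversal=True behaves like the speaker classifier on the prosody embedding;
    # reversal=False behaves like the new emotion classifiers added in this patch.
    _, loss = AuxClassifier(64, 4, reversal=True)(enc_out, labels)
    loss.backward()
    print(enc_out.grad.abs().max() <= 0.25)  # reversed gradient is clipped at the bound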
diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index 14c7b9b5..5f44dc22 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -590,7 +590,8 @@ class VitsGeneratorLoss(nn.Module):
         self.dur_loss_alpha = c.dur_loss_alpha
         self.mel_loss_alpha = c.mel_loss_alpha
         self.consistency_loss_alpha = c.consistency_loss_alpha
-        self.text_enc_spk_reversal_loss_alpha = c.text_enc_spk_reversal_loss_alpha
+        self.emotion_classifier_alpha = c.emotion_classifier_loss_alpha
+        self.speaker_classifier_alpha = c.speaker_classifier_loss_alpha
 
         self.stft = TorchSTFT(
             c.audio.fft_size,
@@ -665,7 +666,9 @@ class VitsGeneratorLoss(nn.Module):
         gt_cons_emb=None,
         syn_cons_emb=None,
         loss_prosody_enc_spk_rev_classifier=None,
+        loss_prosody_enc_emo_classifier=None,
         loss_text_enc_spk_rev_classifier=None,
+        loss_text_enc_emo_classifier=None,
     ):
         """
         Shapes:
@@ -702,14 +705,27 @@ class VitsGeneratorLoss(nn.Module):
             return_dict["loss_consistency_enc"] = loss_enc
 
         if loss_prosody_enc_spk_rev_classifier is not None:
+            loss_prosody_enc_spk_rev_classifier = loss_prosody_enc_spk_rev_classifier * self.speaker_classifier_alpha
             loss += loss_prosody_enc_spk_rev_classifier
             return_dict["loss_prosody_enc_spk_rev_classifier"] = loss_prosody_enc_spk_rev_classifier
 
+        if loss_prosody_enc_emo_classifier is not None:
+            loss_prosody_enc_emo_classifier = loss_prosody_enc_emo_classifier * self.emotion_classifier_alpha
+            loss += loss_prosody_enc_emo_classifier
+            return_dict["loss_prosody_enc_emo_classifier"] = loss_prosody_enc_emo_classifier
+
+
         if loss_text_enc_spk_rev_classifier is not None:
-            loss_text_enc_spk_rev_classifier = loss_text_enc_spk_rev_classifier * self.text_enc_spk_reversal_loss_alpha
+            loss_text_enc_spk_rev_classifier = loss_text_enc_spk_rev_classifier * self.speaker_classifier_alpha
             loss += loss_text_enc_spk_rev_classifier
             return_dict["loss_text_enc_spk_rev_classifier"] = loss_text_enc_spk_rev_classifier
 
+        if loss_text_enc_emo_classifier is not None:
+            loss_text_enc_emo_classifier = loss_text_enc_emo_classifier * self.emotion_classifier_alpha
+            loss += loss_text_enc_emo_classifier
+            return_dict["loss_text_enc_emo_classifier"] = loss_text_enc_emo_classifier
+
+
         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
         return_dict["loss_kl"] = loss_kl
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index da65c516..e28b259b 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -541,6 +541,7 @@
     emotion_embedding_dim: int = 0
     num_emotions: int = 0
    use_text_enc_spk_reversal_classifier: bool = False
+    use_text_enc_emo_classifier: bool = False
 
     # prosody encoder
     use_prosody_encoder: bool = False
@@ -548,6 +549,7 @@
     prosody_encoder_num_heads: int = 1
     prosody_encoder_num_tokens: int = 5
     use_prosody_enc_spk_reversal_classifier: bool = False
+    use_prosody_enc_emo_classifier: bool = False
 
     use_prosody_conditional_flow_module: bool = False
 
@@ -706,6 +708,13 @@ class Vits(BaseTTS):
                 out_channels=self.num_speakers,
                 hidden_channels=256,
             )
+        if self.args.use_prosody_enc_emo_classifier:
+            self.pros_enc_emotion_classifier = ReversalClassifier(
+                in_channels=self.args.prosody_embedding_dim,
+                out_channels=self.num_emotions,
+                hidden_channels=256,
+                reversal=False
+            )
 
         if self.args.use_prosody_conditional_flow_module:
             cond_embedding_dim = 0
@@ -732,6 +741,14 @@
                 hidden_channels=256,
             )
 
+        if self.args.use_text_enc_emo_classifier:
+            self.emo_text_enc_classifier = ReversalClassifier(
+                in_channels=self.args.hidden_channels,
+                out_channels=self.num_emotions,
+                hidden_channels=256,
+                reversal=False
+            )
+
         self.waveform_decoder = HifiganGenerator(
             self.args.hidden_channels,
             1,
@@ -1117,11 +1134,14 @@
         # prosody embedding
         pros_emb = None
         l_pros_speaker = None
+        l_pros_emotion = None
         if self.args.use_prosody_encoder:
             pros_emb = self.prosody_encoder(z).transpose(1, 2)
             if self.args.use_prosody_enc_spk_reversal_classifier:
                 _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
-
+            if self.args.use_prosody_enc_emo_classifier:
+                _, l_pros_emotion = self.pros_enc_emotion_classifier(pros_emb.transpose(1, 2), eid, x_mask=None)
+
         x, m_p, logs_p, x_mask = self.text_encoder(
             x,
             x_lengths,
@@ -1134,6 +1154,11 @@
         if self.args.use_prosody_conditional_flow_module:
             m_p = self.prosody_conditional_module(m_p, x_mask, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb)
 
+        # emotion classifier loss to force the text encoder to keep the emotion information
+        l_text_emotion = None
+        if self.args.use_text_enc_emo_classifier:
+            _, l_text_emotion = self.emo_text_enc_classifier(m_p.transpose(1, 2), eid, x_mask=None)
+
         # reversal speaker loss to force the encoder to be speaker identity free
         l_text_speaker = None
         if self.args.use_text_enc_spk_reversal_classifier:
@@ -1214,7 +1239,9 @@
                 "syn_cons_emb": syn_cons_emb,
                 "slice_ids": slice_ids,
                 "loss_prosody_enc_spk_rev_classifier": l_pros_speaker,
+                "loss_prosody_enc_emo_classifier": l_pros_emotion,
                 "loss_text_enc_spk_rev_classifier": l_text_speaker,
+                "loss_text_enc_emo_classifier": l_text_emotion,
             }
         )
         return outputs
@@ -1531,7 +1558,9 @@
                gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
                syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
                loss_prosody_enc_spk_rev_classifier=self.model_outputs_cache["loss_prosody_enc_spk_rev_classifier"],
+                loss_prosody_enc_emo_classifier=self.model_outputs_cache["loss_prosody_enc_emo_classifier"],
                loss_text_enc_spk_rev_classifier=self.model_outputs_cache["loss_text_enc_spk_rev_classifier"],
+                loss_text_enc_emo_classifier=self.model_outputs_cache["loss_text_enc_emo_classifier"],
            )
 
            return self.model_outputs_cache, loss_dict
@@ -1737,7 +1766,7 @@
             emotion_embeddings = [emotion_mapping[w]["embedding"] for w in batch["audio_files"]]
             emotion_embeddings = torch.FloatTensor(emotion_embeddings)
 
-        if self.emotion_manager is not None and self.emotion_manager.embeddings and self.args.use_emotion_embedding:
+        if self.emotion_manager is not None and self.emotion_manager.embeddings and (self.args.use_emotion_embedding or self.args.use_prosody_enc_emo_classifier or self.args.use_text_enc_emo_classifier):
             emotion_mapping = self.emotion_manager.embeddings
             emotion_names = [emotion_mapping[w]["name"] for w in batch["audio_files"]]
             emotion_ids = [self.emotion_manager.ids[en] for en in emotion_names]
""" emotion_manager = None - if get_from_config_or_model_args_with_default(config, "use_emotion_embedding", False): + if get_from_config_or_model_args_with_default(config, "use_emotion_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False): if get_from_config_or_model_args_with_default(config, "emotions_ids_file", None): emotion_manager = EmotionManager( emotion_id_file_path=get_from_config_or_model_args_with_default(config, "emotions_ids_file", None) @@ -106,7 +106,7 @@ class EmotionManager(EmbeddingManager): ) ) - if get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False): + if get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False): if get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None): emotion_manager = EmotionManager( embeddings_file_path=get_from_config_or_model_args_with_default( diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 333ea46a..a8c07500 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -293,9 +293,9 @@ class Synthesizer(object): # handle emotion emotion_embedding, emotion_id = None, None - if self.tts_emotions_file or ( + if not getattr(self.tts_model, "prosody_encoder", False) and (self.tts_emotions_file or ( getattr(self.tts_model, "emotion_manager", None) and getattr(self.tts_model.emotion_manager, "ids", None) - ): + )): if emotion_name and isinstance(emotion_name, str): if getattr(self.tts_config, "use_external_emotions_embeddings", False) or ( getattr(self.tts_config, "model_args", None) diff --git a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py index ccd48616..e6f94059 100644 --- a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py +++ b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py @@ -43,6 +43,10 @@ config.model_args.d_vector_dim = 128 # prosody embedding config.model_args.use_prosody_encoder = True config.model_args.prosody_embedding_dim = 64 +# active classifier +config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json" +config.model_args.use_prosody_enc_emo_classifier = True +config.model_args.use_text_enc_emo_classifier = True config.save_json(config_path)