diff --git a/TTS/tts/layers/generic/classifier.py b/TTS/tts/layers/generic/classifier.py
index e4f41481..b09ccdc6 100644
--- a/TTS/tts/layers/generic/classifier.py
+++ b/TTS/tts/layers/generic/classifier.py
@@ -50,16 +50,14 @@ class ReversalClassifier(nn.Module):
         return x, loss
 
     @staticmethod
-    def loss(labels, predictions, x_mask):
-        ignore_index = -100
+    def loss(labels, predictions, x_mask=None, ignore_index=-100):
         if x_mask is None:
             x_mask = torch.Tensor([predictions.size(1)]).repeat(predictions.size(0)).int().to(predictions.device)
-
-        ml = torch.max(x_mask)
-        input_mask = torch.arange(ml, device=predictions.device)[None, :] < x_mask[:, None]
-
-        target = labels.repeat(ml.int().item(), 1).transpose(0, 1)
+            ml = torch.max(x_mask)
+            input_mask = torch.arange(ml, device=predictions.device)[None, :] < x_mask[:, None]
+        else:
+            input_mask = x_mask.squeeze().bool()
+        target = labels.repeat(input_mask.size(-1), 1).transpose(0, 1).int().long()
         target[~input_mask] = ignore_index
-
         return nn.functional.cross_entropy(predictions.transpose(1, 2), target, ignore_index=ignore_index)
 
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index e28b259b..17c2cb49 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -552,6 +552,7 @@ class VitsArgs(Coqpit):
     use_prosody_enc_emo_classifier: bool = False
 
     use_prosody_conditional_flow_module: bool = False
+    prosody_conditional_flow_module_on_decoder: bool = False
 
     detach_dp_input: bool = True
     use_language_embedding: bool = False
@@ -1150,15 +1151,6 @@ class Vits(BaseTTS):
             pros_emb=pros_emb if not self.args.use_prosody_conditional_flow_module else None
         )
 
-        # conditional module
-        if self.args.use_prosody_conditional_flow_module:
-            m_p = self.prosody_conditional_module(m_p, x_mask, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb)
-
-        # reversal speaker loss to force the encoder to be speaker identity free
-        l_text_emotion = None
-        if self.args.use_text_enc_emo_classifier:
-            _, l_text_emotion = self.emo_text_enc_classifier(m_p.transpose(1, 2), eid, x_mask=None)
-
         # reversal speaker loss to force the encoder to be speaker identity free
         l_text_speaker = None
         if self.args.use_text_enc_spk_reversal_classifier:
@@ -1167,6 +1159,24 @@
         # flow layers
         z_p = self.flow(z, y_mask, g=g)
 
+        # reversal emotion loss on the flow output (decoder-side conditional module)
+        l_text_emotion = None
+        if self.args.use_text_enc_emo_classifier:
+            if self.args.prosody_conditional_flow_module_on_decoder:
+                _, l_text_emotion = self.emo_text_enc_classifier(z_p.transpose(1, 2), eid, x_mask=y_mask)
+
+        # conditional module
+        if self.args.use_prosody_conditional_flow_module:
+            if not self.args.prosody_conditional_flow_module_on_decoder:
+                m_p = self.prosody_conditional_module(m_p, x_mask, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb)
+            else:
+                z_p = self.prosody_conditional_module(z_p, y_mask, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb)
+
+        # reversal emotion loss on the text encoder output (encoder-side conditional module)
+        if self.args.use_text_enc_emo_classifier:
+            if not self.args.prosody_conditional_flow_module_on_decoder:
+                _, l_text_emotion = self.emo_text_enc_classifier(m_p.transpose(1, 2), eid, x_mask=x_mask)
+
         # duration predictor
         g_dp = g if self.args.condition_dp_on_speaker else None
         if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
@@ -1318,7 +1328,8 @@ class Vits(BaseTTS):
 
         # conditional module
         if self.args.use_prosody_conditional_flow_module:
-            m_p = self.prosody_conditional_module(m_p, x_mask, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb)
+            if not self.args.prosody_conditional_flow_module_on_decoder:
+                m_p = self.prosody_conditional_module(m_p, x_mask, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb)
 
         # duration predictor
         g_dp = g if self.args.condition_dp_on_speaker else None
@@ -1358,6 +1369,12 @@ class Vits(BaseTTS):
             logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
 
         z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
+
+        # conditional module
+        if self.args.use_prosody_conditional_flow_module:
+            if self.args.prosody_conditional_flow_module_on_decoder:
+                z_p = self.prosody_conditional_module(z_p, y_mask, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb, reverse=True)
+
         z = self.flow(z_p, y_mask, g=g, reverse=True)
 
         # upsampling if needed
@@ -1394,7 +1411,7 @@ class Vits(BaseTTS):
 
     @torch.no_grad()
     def inference_voice_conversion(
-        self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None
+        self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None, ref_emotion=None, target_emotion=None
     ):
         """Inference for voice conversion
 
@@ -1417,10 +1434,11 @@ class Vits(BaseTTS):
         speaker_cond_src = reference_speaker_id if reference_speaker_id is not None else reference_d_vector
         speaker_cond_tgt = speaker_id if speaker_id is not None else d_vector
         # print(y.shape, y_lengths.shape)
-        wav, _, _ = self.voice_conversion(y, y_lengths, speaker_cond_src, speaker_cond_tgt)
+        wav, _, _ = self.voice_conversion(y, y_lengths, speaker_cond_src, speaker_cond_tgt, ref_emotion, target_emotion)
         return wav
 
-    def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt):
+
+    def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt, ref_emotion=None, target_emotion=None):
         """Forward pass for voice conversion
 
         TODO: create an end-point for voice conversion
@@ -1441,13 +1459,31 @@ class Vits(BaseTTS):
             g_tgt = F.normalize(speaker_cond_tgt).unsqueeze(-1)
         else:
            raise RuntimeError(" [!] Voice conversion is only supported on multi-speaker models.")
+        # emotion embeddings (stay None when no emotion conditioning is used)
+        ge_src, ge_tgt = None, None
+        if self.args.use_emotion_embedding and ref_emotion is not None and target_emotion is not None:
+            ge_src = self.emb_g(ref_emotion).unsqueeze(-1)
+            ge_tgt = self.emb_g(target_emotion).unsqueeze(-1)
+        elif self.args.use_external_emotions_embeddings and ref_emotion is not None and target_emotion is not None:
+            ge_src = F.normalize(ref_emotion).unsqueeze(-1)
+            ge_tgt = F.normalize(target_emotion).unsqueeze(-1)
 
         z, _, _, y_mask = self.posterior_encoder(y, y_lengths, g=g_src)
         z_p = self.flow(z, y_mask, g=g_src)
+
+        # change the emotion
+        if ge_src is not None and ge_tgt is not None and self.args.use_prosody_conditional_flow_module:
+            if not self.args.prosody_conditional_flow_module_on_decoder:
+                ze = self.prosody_conditional_module(z_p, y_mask, g=ge_src, reverse=True)
+                z_p = self.prosody_conditional_module(ze, y_mask, g=ge_tgt)
+            else:
+                ze = self.prosody_conditional_module(z_p, y_mask, g=ge_src)
+                z_p = self.prosody_conditional_module(ze, y_mask, g=ge_tgt, reverse=True)
+
         z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
         o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
         return o_hat, y_mask, (z, z_p, z_hat)
 
+
     def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
         """Perform a single training step. Run the model forward pass and compute losses.
 
diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py
index e9552d59..803c8888 100644
--- a/TTS/tts/utils/synthesis.py
+++ b/TTS/tts/utils/synthesis.py
@@ -266,6 +266,8 @@ def transfer_voice(
     reference_d_vector=None,
     do_trim_silence=False,
     use_griffin_lim=False,
+    source_emotion_feature=None,
+    target_emotion_feature=None,
 ):
     """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to the vocoder model.
 
@@ -311,6 +313,12 @@ def transfer_voice(
     if reference_d_vector is not None:
         reference_d_vector = embedding_to_torch(reference_d_vector, cuda=use_cuda)
 
+    if source_emotion_feature is not None:
+        source_emotion_feature = embedding_to_torch(source_emotion_feature, cuda=use_cuda)
+
+    if target_emotion_feature is not None:
+        target_emotion_feature = embedding_to_torch(target_emotion_feature, cuda=use_cuda)
+
     # load reference_wav audio
     reference_wav = embedding_to_torch(model.ap.load_wav(reference_wav, sr=model.ap.sample_rate), cuda=use_cuda)
 
@@ -318,7 +326,7 @@ def transfer_voice(
         _func = model.module.inference_voice_conversion
     else:
         _func = model.inference_voice_conversion
-    model_outputs = _func(reference_wav, speaker_id, d_vector, reference_speaker_id, reference_d_vector)
+    model_outputs = _func(reference_wav, speaker_id, d_vector, reference_speaker_id, reference_d_vector, ref_emotion=source_emotion_feature, target_emotion=target_emotion_feature)
 
     # convert outputs to numpy
     # plot results
diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py
index a8c07500..4da59e46 100644
--- a/TTS/utils/synthesizer.py
+++ b/TTS/utils/synthesizer.py
@@ -214,6 +214,8 @@ class Synthesizer(object):
         reference_wav=None,
         reference_speaker_name=None,
         emotion_name=None,
+        source_emotion=None,
+        target_emotion=None,
     ) -> List[int]:
         """🐸 TTS magic. Run all the models and generate speech.
@@ -293,7 +295,7 @@ class Synthesizer(object):
 
         # handle emotion
         emotion_embedding, emotion_id = None, None
-        if not getattr(self.tts_model, "prosody_encoder", False) and (self.tts_emotions_file or (
+        if not reference_wav and not getattr(self.tts_model, "prosody_encoder", False) and (self.tts_emotions_file or (
             getattr(self.tts_model, "emotion_manager", None) and getattr(self.tts_model.emotion_manager, "ids", None)
         )):
             if emotion_name and isinstance(emotion_name, str):
@@ -398,6 +400,32 @@ class Synthesizer(object):
                 reference_speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(
                     reference_wav
                 )
+            # handle the source and target emotions for voice conversion
+            # (emotion embeddings when external embeddings are used, emotion ids otherwise)
+            source_emotion_feature, target_emotion_feature = None, None
+            if source_emotion is not None and target_emotion is not None and not getattr(self.tts_model, "prosody_encoder", False) and (self.tts_emotions_file or (
+                getattr(self.tts_model, "emotion_manager", None) and getattr(self.tts_model.emotion_manager, "ids", None)
+            )):
+                if source_emotion and isinstance(source_emotion, str):
+                    if getattr(self.tts_config, "use_external_emotions_embeddings", False) or (
+                        getattr(self.tts_config, "model_args", None)
+                        and getattr(self.tts_config.model_args, "use_external_emotions_embeddings", False)
+                    ):
+                        # get the average emotion embedding from the saved embeddings.
+                        source_emotion_feature = self.tts_model.emotion_manager.get_mean_embedding(
+                            source_emotion, num_samples=None, randomize=False
+                        )
+                        source_emotion_feature = np.array(source_emotion_feature)[None, :]  # [1 x embedding_dim]
+                        # target
+                        target_emotion_feature = self.tts_model.emotion_manager.get_mean_embedding(
+                            target_emotion, num_samples=None, randomize=False
+                        )
+                        target_emotion_feature = np.array(target_emotion_feature)[None, :]  # [1 x embedding_dim]
+                    else:
+                        # get emotion idx
+                        source_emotion_feature = self.tts_model.emotion_manager.ids[source_emotion]
+                        target_emotion_feature = self.tts_model.emotion_manager.ids[target_emotion]
+
             outputs = transfer_voice(
                 model=self.tts_model,
@@ -409,6 +437,8 @@ class Synthesizer(object):
                 use_griffin_lim=use_gl,
                 reference_speaker_id=reference_speaker_id,
                 reference_d_vector=reference_speaker_embedding,
+                source_emotion_feature=source_emotion_feature,
+                target_emotion_feature=target_emotion_feature,
             )
             waveform = outputs
             if not use_gl:
diff --git a/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py b/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py
index c80e6771..e89e538b 100644
--- a/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py
+++ b/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py
@@ -49,6 +49,9 @@ config.model_args.use_text_enc_spk_reversal_classifier = False
 
 config.model_args.use_prosody_conditional_flow_module = True
+config.model_args.prosody_conditional_flow_module_on_decoder = True
+config.model_args.use_text_enc_emo_classifier = True
+
 # consistency loss
 # config.model_args.use_emotion_encoder_as_loss = True
 # config.model_args.encoder_model_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/model_se.pth.tar"
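For review context, a minimal usage sketch of the emotion-transfer path added above, going through `TTS.utils.synthesizer.Synthesizer.tts`. It assumes `synth` is an already-initialized `Synthesizer` whose VITS checkpoint was trained with `use_prosody_conditional_flow_module` and has an emotion manager / emotions file loaded; the speaker names, file paths, and emotion labels are placeholders, not values taken from this patch.

```python
# Hedged usage sketch (not part of the patch): voice conversion with an emotion change.
# `synth` is assumed to be an already-loaded TTS.utils.synthesizer.Synthesizer instance;
# speaker names, paths, and emotion labels below are placeholders.
wav = synth.tts(
    text="",                                  # ignored on the voice-conversion path
    speaker_name="target_speaker",            # identity to convert the reference into
    reference_wav="source_utterance.wav",     # utterance to be converted
    reference_speaker_name="source_speaker",  # speaker of the reference clip
    source_emotion="Neutral",                 # emotion of the reference clip
    target_emotion="Angry",                   # emotion imposed via the conditional flow module
)
synth.save_wav(wav, "converted_angry.wav")
```

Internally, `voice_conversion` runs the prosody conditional flow module twice on `z_p`: once conditioned on the source emotion and once, in the opposite direction, conditioned on the target emotion, with the direction pair depending on `prosody_conditional_flow_module_on_decoder`.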