From bd35371944cb14da696536e6f961c2be9bd7b633 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Fri, 27 May 2022 16:00:41 -0300
Subject: [PATCH] Add prosody encoder inference support

---
 TTS/bin/synthesize.py                         |  2 +
 TTS/tts/models/vits.py                        | 49 ++++++++++++++++---
 TTS/tts/utils/synthesis.py                    | 15 ++++++
 TTS/utils/synthesizer.py                      | 18 +++++++
 ...t_vits_speaker_emb_with_prosody_encoder.py |  8 +--
 5 files changed, 81 insertions(+), 11 deletions(-)

diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py
index 2f32ec96..ee44731a 100755
--- a/TTS/bin/synthesize.py
+++ b/TTS/bin/synthesize.py
@@ -179,6 +179,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
         default=None,
     )
     parser.add_argument("--style_wav", type=str, help="Wav path file for prosody reference.", default=None)
+    parser.add_argument("--style_speaker_name", type=str, help="The speaker name from the style_wav. If not provided, the speaker embedding will be computed using the speaker encoder.", default=None)
     parser.add_argument(
         "--capacitron_style_text", type=str, help="Transcription of the style_wav reference.", default=None
     )
@@ -325,6 +326,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
         style_text=args.capacitron_style_text,
         reference_speaker_name=args.reference_speaker_idx,
         emotion_name=args.emotion_idx,
+        style_speaker_name=args.style_speaker_name,
     )
 
     # save the results
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 7ea87ebc..41808866 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -981,7 +981,7 @@ class Vits(BaseTTS):
     @staticmethod
     def _set_cond_input(aux_input: Dict):
         """Set the speaker conditioning input based on the multi-speaker mode."""
-        sid, g, lid, eid, eg, pf = None, None, None, None, None, None
+        sid, g, lid, eid, eg, pf, ssid, ssg = None, None, None, None, None, None, None, None
         if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
             sid = aux_input["speaker_ids"]
             if sid.ndim == 0:
@@ -1010,7 +1010,18 @@ class Vits(BaseTTS):
             pf = aux_input["style_feature"]
             if pf.ndim == 2:
                 pf = pf.unsqueeze_(0)
-        return sid, g, lid, eid, eg, pf
+
+        if "style_speaker_id" in aux_input and aux_input["style_speaker_id"] is not None:
+            ssid = aux_input["style_speaker_id"]
+            if ssid.ndim == 0:
+                ssid = ssid.unsqueeze_(0)
+
+        if "style_speaker_d_vector" in aux_input and aux_input["style_speaker_d_vector"] is not None:
+            ssg = F.normalize(aux_input["style_speaker_d_vector"]).unsqueeze(-1)
+            if ssg.ndim == 2:
+                ssg = ssg.unsqueeze_(0)
+
+        return sid, g, lid, eid, eg, pf, ssid, ssg
 
     def _set_speaker_input(self, aux_input: Dict):
         d_vectors = aux_input.get("d_vectors", None)
@@ -1130,7 +1141,7 @@ class Vits(BaseTTS):
             - syn_cons_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
         """
         outputs = {}
-        sid, g, lid, eid, eg, _ = self._set_cond_input(aux_input)
+        sid, g, lid, eid, eg, _, _, _ = self._set_cond_input(aux_input)
         # speaker embedding
         if self.args.use_speaker_embedding and sid is not None:
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
@@ -1317,7 +1328,7 @@ class Vits(BaseTTS):
             - m_p: :math:`[B, C, T_dec]`
             - logs_p: :math:`[B, C, T_dec]`
         """
-        sid, g, lid, eid, eg, pf = self._set_cond_input(aux_input)
+        sid, g, lid, eid, eg, pf, ssid, ssg = self._set_cond_input(aux_input)
         x_lengths = self._set_x_lengths(x, aux_input)
 
         # speaker embedding
@@ -1336,13 +1347,17 @@ class Vits(BaseTTS):
         # prosody embedding
         pros_emb = None
         if self.args.use_prosody_encoder:
+            # speaker embedding for the style speaker
+            if self.args.use_speaker_embedding and ssid is not None:
+                ssg = self.emb_g(ssid).unsqueeze(-1)
+
             # extract posterior encoder feature
             pf_lengths = torch.tensor([pf.size(-1)]).to(pf.device)
-            z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=g)
+            z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=ssg)
             if not self.args.use_prosody_encoder_z_p_input:
                 pros_emb, _ = self.prosody_encoder(z_pro, pf_lengths)
             else:
-                z_p_inf = self.flow(z_pro, z_pro_y_mask, g=g)
+                z_p_inf = self.flow(z_pro, z_pro_y_mask, g=ssg)
                 pros_emb, _ = self.prosody_encoder(z_p_inf, pf_lengths)
 
             pros_emb = pros_emb.transpose(1, 2)
@@ -1687,7 +1702,7 @@ class Vits(BaseTTS):
         config = self.config
 
         # extract speaker and language info
-        text, speaker_name, style_wav, language_name, emotion_name = None, None, None, None, None
+        text, speaker_name, style_wav, language_name, emotion_name, style_speaker_name = None, None, None, None, None, None
 
         if isinstance(sentence_info, list):
             if len(sentence_info) == 1:
@@ -1700,23 +1715,37 @@ class Vits(BaseTTS):
                 text, speaker_name, style_wav, language_name = sentence_info
             elif len(sentence_info) == 5:
                 text, speaker_name, style_wav, language_name, emotion_name = sentence_info
+            elif len(sentence_info) == 6:
+                text, speaker_name, style_wav, language_name, emotion_name, style_speaker_name = sentence_info
         else:
             text = sentence_info
 
+        if style_wav and style_speaker_name is None:
+            raise RuntimeError(
+                f" [!] You must provide the style_speaker_name for the style_wav !!"
+            )
+
         # get speaker id/d_vector
-        speaker_id, d_vector, language_id, emotion_id, emotion_embedding = None, None, None, None, None
+        speaker_id, d_vector, language_id, emotion_id, emotion_embedding, style_speaker_id, style_speaker_d_vector = None, None, None, None, None, None, None
         if hasattr(self, "speaker_manager"):
             if config.use_d_vector_file:
                 if speaker_name is None:
                     d_vector = self.speaker_manager.get_random_embeddings()
                 else:
                     d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)
+
+                if style_wav is not None:
+                    style_speaker_d_vector = self.speaker_manager.get_mean_embedding(style_speaker_name, num_samples=None, randomize=False)
+
             elif config.use_speaker_embedding:
                 if speaker_name is None:
                     speaker_id = self.speaker_manager.get_random_id()
                 else:
                     speaker_id = self.speaker_manager.ids[speaker_name]
 
+                if style_wav is not None:
+                    style_speaker_id = self.speaker_manager.ids[style_speaker_name]
+
         # get language id
         if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
             language_id = self.language_manager.ids[language_name]
@@ -1740,6 +1769,8 @@ class Vits(BaseTTS):
             "text": text,
             "speaker_id": speaker_id,
             "style_wav": style_wav,
+            "style_speaker_id": style_speaker_id,
+            "style_speaker_d_vector": style_speaker_d_vector,
             "d_vector": d_vector,
             "language_id": language_id,
             "language_name": language_name,
@@ -1773,6 +1804,8 @@ class Vits(BaseTTS):
             language_id=aux_inputs["language_id"],
             emotion_embedding=aux_inputs["emotion_embedding"],
             emotion_id=aux_inputs["emotion_ids"],
+            style_speaker_id=aux_inputs["style_speaker_id"],
+            style_speaker_d_vector=aux_inputs["style_speaker_d_vector"],
             use_griffin_lim=True,
             do_trim_silence=False,
         ).values()
diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py
index 803c8888..82ce79c6 100644
--- a/TTS/tts/utils/synthesis.py
+++ b/TTS/tts/utils/synthesis.py
@@ -31,6 +31,8 @@ def run_model_torch(
     language_id: torch.Tensor = None,
     emotion_id: torch.Tensor = None,
     emotion_embedding: torch.Tensor = None,
+    style_speaker_id: torch.Tensor = None,
+    style_speaker_d_vector: torch.Tensor = None,
 ) -> Dict:
     """Run a torch model for inference. It does not support batch inference.
 
@@ -60,6 +62,8 @@ def run_model_torch(
             "language_ids": language_id,
             "emotion_ids": emotion_id,
             "emotion_embeddings": emotion_embedding,
+            "style_speaker_id": style_speaker_id,
+            "style_speaker_d_vector": style_speaker_d_vector,
         },
     )
     return outputs
@@ -128,6 +132,8 @@ def synthesis(
     language_id=None,
     emotion_id=None,
     emotion_embedding=None,
+    style_speaker_id=None,
+    style_speaker_d_vector=None,
 ):
     """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output
     features to be passed to the vocoder model.
@@ -205,6 +211,13 @@ def synthesis(
     if emotion_embedding is not None:
         emotion_embedding = embedding_to_torch(emotion_embedding, cuda=use_cuda)
 
+    if style_speaker_id is not None:
+        style_speaker_id = id_to_torch(style_speaker_id, cuda=use_cuda)
+
+    if style_speaker_d_vector is not None:
+        style_speaker_d_vector = embedding_to_torch(style_speaker_d_vector, cuda=use_cuda)
+
+
     if not isinstance(style_feature, dict):
         # GST or Capacitron style mel
         style_feature = numpy_to_torch(style_feature, torch.float, cuda=use_cuda)
@@ -229,6 +242,8 @@ def synthesis(
         language_id=language_id,
         emotion_id=emotion_id,
         emotion_embedding=emotion_embedding,
+        style_speaker_id=style_speaker_id,
+        style_speaker_d_vector=style_speaker_d_vector,
     )
     model_outputs = outputs["model_outputs"]
     model_outputs = model_outputs[0].data.cpu().numpy()
diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py
index 4da59e46..3e328506 100644
--- a/TTS/utils/synthesizer.py
+++ b/TTS/utils/synthesizer.py
@@ -216,6 +216,7 @@ class Synthesizer(object):
         emotion_name=None,
         source_emotion=None,
         target_emotion=None,
+        style_speaker_name=None,
     ) -> List[int]:
         """🐸 TTS magic. Run all the models and generate speech.
 
@@ -247,6 +248,8 @@ class Synthesizer(object):
         # handle multi-speaker
         speaker_embedding = None
         speaker_id = None
+        style_speaker_id = None
+        style_speaker_embedding = None
         if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "ids"):
             if speaker_name and isinstance(speaker_name, str):
                 if self.tts_config.use_d_vector_file:
@@ -255,10 +258,20 @@ class Synthesizer(object):
                         speaker_name, num_samples=None, randomize=False
                     )
                     speaker_embedding = np.array(speaker_embedding)[None, :]  # [1 x embedding_dim]
+
+                    if style_speaker_name is not None:
+                        style_speaker_embedding = self.tts_model.speaker_manager.get_mean_embedding(
+                            style_speaker_name, num_samples=None, randomize=False
+                        )
+                        style_speaker_embedding = np.array(style_speaker_embedding)[None, :]  # [1 x embedding_dim]
+
                 else:
                     # get speaker idx from the speaker name
                     speaker_id = self.tts_model.speaker_manager.ids[speaker_name]
 
+                    if style_speaker_name is not None:
+                        style_speaker_id = self.tts_model.speaker_manager.ids[style_speaker_name]
+
             elif not speaker_name and not speaker_wav:
                 raise ValueError(
                     " [!] Look like you use a multi-speaker model. "
@@ -327,6 +340,9 @@ class Synthesizer(object):
         if speaker_wav is not None:
             speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)
 
+        if style_wav is not None and style_speaker_name is None:
+            style_speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(style_wav)
+
         use_gl = self.vocoder_model is None
 
         if not reference_wav:
@@ -340,6 +356,8 @@ class Synthesizer(object):
                     speaker_id=speaker_id,
                     style_wav=style_wav,
                     style_text=style_text,
+                    style_speaker_id=style_speaker_id,
+                    style_speaker_d_vector=style_speaker_embedding,
                     use_griffin_lim=use_gl,
                     d_vector=speaker_embedding,
                     language_id=language_id,
diff --git a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
index e7cad601..3737a473 100644
--- a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
+++ b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
@@ -26,7 +26,7 @@ config = VitsConfig(
     print_step=1,
     print_eval=True,
     test_sentences=[
-        ["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None],
+        ["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None, "ljspeech-2"],
     ],
 )
 # set audio config
@@ -46,7 +46,7 @@ config.model_args.prosody_embedding_dim = 64
 # active classifier
 config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
 config.model_args.use_prosody_enc_emo_classifier = False
-config.model_args.use_text_enc_emo_classifier = True
+config.model_args.use_text_enc_emo_classifier = False
 config.model_args.use_prosody_encoder_z_p_input = True
 
 config.model_args.prosody_encoder_type = "vae"
@@ -75,11 +75,13 @@ continue_config_path = os.path.join(continue_path, "config.json")
 continue_restore_path, _ = get_last_checkpoint(continue_path)
 out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
 speaker_id = "ljspeech-1"
+style_speaker_name = "ljspeech-2"
 style_wav_path = "tests/data/ljspeech/wavs/LJ001-0001.wav"
 continue_speakers_path = os.path.join(continue_path, "speakers.json")
 
 
-inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} --style_wav {style_wav_path}"
+print("Testing inference !")
+inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} --style_wav {style_wav_path} --style_speaker_name {style_speaker_name}"
 run_cli(inference_command)
 
 # restore the model and continue training for one more epoch