diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index bdc83d5e..c578a301 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -178,7 +178,7 @@ If you don't specify any models, then it uses LJSpeech based English model. help="wav file(s) to condition a multi-speaker TTS model with a Speaker Encoder. You can give multiple file paths. The d_vectors is computed as their average.", default=None, ) - parser.add_argument("--gst_style", help="Wav path file for GST stylereference.", default=None) + parser.add_argument("--gst_style", help="Wav path file for GST style reference.", default=None) parser.add_argument( "--list_speaker_idxs", help="List available speaker ids for the defined multi-speaker model.", @@ -317,6 +317,7 @@ If you don't specify any models, then it uses LJSpeech based English model. args.speaker_idx, args.language_idx, args.speaker_wav, + style_wav=args.gst_style, reference_wav=args.reference_wav, reference_speaker_name=args.reference_speaker_idx, emotion_name=args.emotion_idx, diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index 64017a21..bb2a823e 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -117,7 +117,7 @@ def load_tts_samples( if eval_split: if meta_file_val: meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers) - meta_data_eval = [{**item, **{"language": language}} for item in meta_data_eval] + meta_data_eval = [{**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_eval] else: meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_max_size, eval_split_size) meta_data_eval_all += meta_data_eval diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py index 8341f5bb..cb2d7548 100644 --- a/TTS/tts/models/tacotron.py +++ b/TTS/tts/models/tacotron.py @@ -189,7 +189,7 @@ class Tacotron(BaseTacotron): encoder_outputs = self.encoder(inputs) if self.gst and self.use_gst: # B x gst_dim - encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"]) + encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_feature"], aux_input["d_vectors"]) if self.num_speakers > 1: if not self.use_d_vector_file: # B x 1 x speaker_embed_dim diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py index d4e665e3..58aeecb2 100644 --- a/TTS/tts/models/tacotron2.py +++ b/TTS/tts/models/tacotron2.py @@ -215,7 +215,7 @@ class Tacotron2(BaseTacotron): if self.gst and self.use_gst: # B x gst_dim - encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"]) + encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_feature"], aux_input["d_vectors"]) if self.num_speakers > 1: if not self.use_d_vector_file: diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index be10cad1..8a177f91 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -508,6 +508,8 @@ class VitsArgs(Coqpit): # prosody encoder use_prosody_encoder: bool = False prosody_embedding_dim: int = 0 + prosody_encoder_num_heads: int = 1 + prosody_encoder_num_tokens: int = 5 detach_dp_input: bool = True use_language_embedding: bool = False @@ -642,8 +644,8 @@ class Vits(BaseTTS): if self.args.use_prosody_encoder: self.prosody_encoder = GST( num_mel=self.args.hidden_channels, - num_heads=1, - num_style_tokens=5, + num_heads=self.args.prosody_encoder_num_heads, + num_style_tokens=self.args.prosody_encoder_num_tokens, gst_embedding_dim=self.args.prosody_embedding_dim, ) self.speaker_reversal_classifier = ReversalClassifier( @@ -840,7 +842,7 @@ class Vits(BaseTTS): @staticmethod def _set_cond_input(aux_input: Dict): """Set the speaker conditioning input based on the multi-speaker mode.""" - sid, g, lid, eid, eg = None, None, None, None, None + sid, g, lid, eid, eg, pf = None, None, None, None, None, None if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None: sid = aux_input["speaker_ids"] if sid.ndim == 0: @@ -865,7 +867,11 @@ class Vits(BaseTTS): if eg.ndim == 2: eg = eg.unsqueeze_(0) - return sid, g, lid, eid, eg + if "style_feature" in aux_input and aux_input["style_feature"] is not None: + pf = aux_input["style_feature"] + if pf.ndim == 2: + pf = pf.unsqueeze_(0) + return sid, g, lid, eid, eg, pf def _set_speaker_input(self, aux_input: Dict): d_vectors = aux_input.get("d_vectors", None) @@ -968,7 +974,7 @@ class Vits(BaseTTS): - syn_cons_emb: :math:`[B, 1, speaker_encoder.proj_dim]` """ outputs = {} - sid, g, lid, eid, eg = self._set_cond_input(aux_input) + sid, g, lid, eid, eg, _ = self._set_cond_input(aux_input) # speaker embedding if self.args.use_speaker_embedding and sid is not None: g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] @@ -998,14 +1004,14 @@ class Vits(BaseTTS): if self.args.use_prosody_encoder: pros_emb = self.prosody_encoder(z).transpose(1, 2) _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None) - + # print("Encoder input", x.shape) x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb) - + # print("X shape:", x.shape, "m_p shape:", m_p.shape, "x_mask:", x_mask.shape, "x_lengths:", x_lengths.shape) # flow layers z_p = self.flow(z, y_mask, g=g) - + # print("Y mask:", y_mask.shape) # duration predictor - g_dp = g + g_dp = g if self.args.condition_dp_on_speaker else None if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder: if g_dp is None: g_dp = eg @@ -1092,6 +1098,7 @@ class Vits(BaseTTS): "language_ids": None, "emotion_embeddings": None, "emotion_ids": None, + "style_feature": None, }, ): # pylint: disable=dangerous-default-value """ @@ -1112,7 +1119,7 @@ class Vits(BaseTTS): - m_p: :math:`[B, C, T_dec]` - logs_p: :math:`[B, C, T_dec]` """ - sid, g, lid, eid, eg = self._set_cond_input(aux_input) + sid, g, lid, eid, eg, pf = self._set_cond_input(aux_input) x_lengths = self._set_x_lengths(x, aux_input) # speaker embedding @@ -1135,29 +1142,42 @@ class Vits(BaseTTS): if self.args.use_language_embedding and lid is not None: lang_emb = self.emb_l(lid).unsqueeze(-1) - x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg) + # prosody embedding + pros_emb = None + if self.args.use_prosody_encoder: + # extract posterior encoder feature + pf_lengths = torch.tensor([pf.size(-1)]).to(pf.device) + z_pro, _, _, _ = self.posterior_encoder(pf, pf_lengths, g=g) + pros_emb = self.prosody_encoder(z_pro).transpose(1, 2) + + x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb) # duration predictor + g_dp = g if self.args.condition_dp_on_speaker else None if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder: - if g is None: + if g_dp is None: g_dp = eg else: - g_dp = torch.cat([g, eg], dim=1) # [b, h1+h2, 1] - else: - g_dp = g + g_dp = torch.cat([g_dp, eg], dim=1) # [b, h1+h2, 1] + + if self.args.use_prosody_encoder: + if g_dp is None: + g_dp = pros_emb + else: + g_dp = torch.cat([g_dp, pros_emb], dim=1) # [b, h1+h2, 1] if self.args.use_sdp: logw = self.duration_predictor( x, x_mask, - g=g_dp if self.args.condition_dp_on_speaker else None, + g=g_dp, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb, ) else: logw = self.duration_predictor( - x, x_mask, g=g_dp if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb + x, x_mask, g=g_dp, lang_emb=lang_emb ) w = torch.exp(logw) * x_mask * self.length_scale @@ -1175,9 +1195,22 @@ class Vits(BaseTTS): z = self.flow(z_p, y_mask, g=g, reverse=True) o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g) - outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p} + outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p, "durations": w_ceil} return outputs + def compute_style_feature(self, style_wav_path): + style_wav, sr = torchaudio.load(style_wav_path) + if sr != self.config.audio.sample_rate: + raise RuntimeError(" [!] Style reference need to have sampling rate equal to {self.config.audio.sample_rate} !!") + y = wav_to_spec( + style_wav, + self.config.audio.fft_size, + self.config.audio.hop_length, + self.config.audio.win_length, + center=False, + ) + return y + @torch.no_grad() def inference_voice_conversion( self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 4208b6aa..ac92d345 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -14,18 +14,18 @@ def numpy_to_torch(np_array, dtype, cuda=False): return tensor -def compute_style_mel(style_wav, ap, cuda=False): - style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0) +def compute_style_feature(style_wav, ap, cuda=False): + style_feature = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0) if cuda: - return style_mel.cuda() - return style_mel + return style_feature.cuda() + return style_feature def run_model_torch( model: nn.Module, inputs: torch.Tensor, speaker_id: int = None, - style_mel: torch.Tensor = None, + style_feature: torch.Tensor = None, d_vector: torch.Tensor = None, language_id: torch.Tensor = None, emotion_id: torch.Tensor = None, @@ -37,7 +37,7 @@ def run_model_torch( model (nn.Module): The model to run inference. inputs (torch.Tensor): Input tensor with character ids. speaker_id (int, optional): Input speaker ids for multi-speaker models. Defaults to None. - style_mel (torch.Tensor, optional): Spectrograms used for voice styling . Defaults to None. + style_feature (torch.Tensor, optional): Spectrograms used for voice styling . Defaults to None. d_vector (torch.Tensor, optional): d-vector for multi-speaker models . Defaults to None. Returns: @@ -54,7 +54,7 @@ def run_model_torch( "x_lengths": input_lengths, "speaker_ids": speaker_id, "d_vectors": d_vector, - "style_mel": style_mel, + "style_feature": style_feature, "language_ids": language_id, "emotion_ids": emotion_id, "emotion_embeddings": emotion_embedding, @@ -161,12 +161,17 @@ def synthesis( Language ID passed to the language embedding layer in multi-langual model. Defaults to None. """ # GST processing - style_mel = None - if CONFIG.has("gst") and CONFIG.gst and style_wav is not None: - if isinstance(style_wav, dict): - style_mel = style_wav - else: - style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda) + style_feature = None + if style_wav is not None: + if CONFIG.has("gst") and CONFIG.gst: + if isinstance(style_wav, dict): + style_feature = style_wav + else: + style_feature = compute_style_feature(style_wav, model.ap, cuda=use_cuda) + if hasattr(model, 'compute_style_feature'): + style_feature = model.compute_style_feature(style_wav) + + # convert text to sequence of token IDs text_inputs = np.asarray( model.tokenizer.text_to_ids(text, language=language_id), @@ -188,8 +193,8 @@ def synthesis( if emotion_embedding is not None: emotion_embedding = embedding_to_torch(emotion_embedding, cuda=use_cuda) - if not isinstance(style_mel, dict): - style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda) + if not isinstance(style_feature, dict): + style_feature = numpy_to_torch(style_feature, torch.float, cuda=use_cuda) text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda) text_inputs = text_inputs.unsqueeze(0) # synthesize voice @@ -197,7 +202,7 @@ def synthesis( model, text_inputs, speaker_id, - style_mel, + style_feature, d_vector=d_vector, language_id=language_id, emotion_id=emotion_id, diff --git a/tests/tts_tests/test_vits_d_vector_with_external_emotion_train.py b/tests/tts_tests/test_vits_d_vector_with_external_emotion_train.py index 19a781dd..e5046826 100644 --- a/tests/tts_tests/test_vits_d_vector_with_external_emotion_train.py +++ b/tests/tts_tests/test_vits_d_vector_with_external_emotion_train.py @@ -47,7 +47,7 @@ config.model_args.use_emotion_embedding = False config.model_args.emotion_embedding_dim = 256 config.model_args.emotion_just_encoder = True config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json" - +config.use_style_weighted_sampler = True # consistency loss # config.model_args.use_emotion_encoder_as_loss = True # config.model_args.encoder_model_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/model_se.pth.tar" @@ -64,6 +64,13 @@ command_train = ( "--coqpit.datasets.0.meta_file_val metadata.csv " "--coqpit.datasets.0.path tests/data/ljspeech " "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt " + "--coqpit.datasets.0.speech_style style1 " + "--coqpit.datasets.1.name ljspeech_test " + "--coqpit.datasets.1.meta_file_train metadata.csv " + "--coqpit.datasets.1.meta_file_val metadata.csv " + "--coqpit.datasets.1.path tests/data/ljspeech " + "--coqpit.datasets.1.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt " + "--coqpit.datasets.1.speech_style style2 " "--coqpit.test_delay_epochs 0" ) run_cli(command_train) diff --git a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py index 7a1cd6ef..6ff4412b 100644 --- a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py +++ b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py @@ -26,7 +26,7 @@ config = VitsConfig( print_step=1, print_eval=True, test_sentences=[ - ["Be a voice, not an echo.", "ljspeech-1", None, None, "ljspeech-1"], + ["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None], ], ) # set audio config @@ -38,11 +38,11 @@ config.model_args.use_speaker_embedding = True config.model_args.use_d_vector_file = False config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json" config.model_args.speaker_embedding_channels = 128 -config.model_args.d_vector_dim = 256 +config.model_args.d_vector_dim = 128 # prosody embedding config.model_args.use_prosody_encoder = True -config.model_args.prosody_embedding_dim = 256 +config.model_args.prosody_embedding_dim = 64 config.save_json(config_path) @@ -67,12 +67,11 @@ continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) out_wav_path = os.path.join(get_tests_output_path(), "output.wav") speaker_id = "ljspeech-1" -emotion_id = "ljspeech-3" +style_wav_path = "tests/data/ljspeech/wavs/LJ001-0001.wav" continue_speakers_path = os.path.join(continue_path, "speakers.json") -continue_emotion_path = os.path.join(continue_path, "speakers.json") -inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --emotion_idx {emotion_id} --speakers_file_path {continue_speakers_path} --emotions_file_path {continue_emotion_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" +inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} --gst_style {style_wav_path}" run_cli(inference_command) # restore the model and continue training for one more epoch