diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py index 2e611ac8..4398c960 100644 --- a/TTS/tts/datasets/formatters.py +++ b/TTS/tts/datasets/formatters.py @@ -417,17 +417,17 @@ def esd(root_path, meta_files, ignored_speakers=None): if speaker_id in ignored_speakers: continue - with open(meta_file, "r", encoding="latin-1") as file_text: + with open(meta_file, "r", encoding="utf-8") as file_text: try: metadata = file_text.readlines() except Exception as e: print(f"The file {meta_file} break the import with the following error: ") raise e - for data in metadata: # this dataset have problems with csv separator, some files use just space others \t data = data.replace("\n", "").replace("\t", " ") if not data: + print(meta_file, data) continue splits = data.split(" ") @@ -435,10 +435,12 @@ def esd(root_path, meta_files, ignored_speakers=None): emotion_id = splits[-1] # all except the first and last position is the sentence text = " ".join(splits[1:-1]) + for split in meta_files: wav_file = os.path.join(root_path, speaker_id, emotion_id, split, file_id + ".wav") if os.path.exists(wav_file): items.append({"text": text, "audio_file": wav_file, "speaker_name": "ESD_" + speaker_id}) + return items diff --git a/TTS/tts/layers/vits/networks.py b/TTS/tts/layers/vits/networks.py index e669d589..cf3cd14b 100644 --- a/TTS/tts/layers/vits/networks.py +++ b/TTS/tts/layers/vits/networks.py @@ -39,6 +39,7 @@ class TextEncoder(nn.Module): dropout_p: float, language_emb_dim: int = None, emotion_emb_dim: int = None, + prosody_emb_dim: int = None, ): """Text Encoder for VITS model. @@ -66,6 +67,9 @@ class TextEncoder(nn.Module): if emotion_emb_dim: hidden_channels += emotion_emb_dim + if prosody_emb_dim: + hidden_channels += prosody_emb_dim + self.encoder = RelativePositionTransformer( in_channels=hidden_channels, out_channels=hidden_channels, @@ -81,7 +85,7 @@ class TextEncoder(nn.Module): self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - def forward(self, x, x_lengths, lang_emb=None, emo_emb=None): + def forward(self, x, x_lengths, lang_emb=None, emo_emb=None, pros_emb=None): """ Shapes: - x: :math:`[B, T]` @@ -98,6 +102,9 @@ class TextEncoder(nn.Module): if emo_emb is not None: x = torch.cat((x, emo_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1) + if pros_emb is not None: + x = torch.cat((x, pros_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1) + x = torch.transpose(x, 1, -1) # [b, h, t] x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) # [b, 1, t] diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 326897e4..78750f4c 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -33,6 +33,8 @@ from TTS.tts.utils.visual import plot_alignment from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results +from TTS.tts.layers.tacotron.gst_layers import GST + ############################## # IO / Feature extraction ############################## @@ -538,6 +540,11 @@ class VitsArgs(Coqpit): external_emotions_embs_file: str = None emotion_embedding_dim: int = 0 num_emotions: int = 0 + emotion_just_encoder: bool = False + + # prosody encoder + use_prosody_encoder: bool = False + prosody_embedding_dim: int = 0 detach_dp_input: bool = True use_language_embedding: bool = False @@ -624,6 +631,7 @@ class Vits(BaseTTS): self.args.dropout_p_text_encoder, language_emb_dim=self.embedded_language_dim, emotion_emb_dim=self.args.emotion_embedding_dim, + prosody_emb_dim=self.args.prosody_embedding_dim, ) self.posterior_encoder = PosteriorEncoder( @@ -645,26 +653,42 @@ class Vits(BaseTTS): cond_channels=self.cond_embedding_dim, ) + dp_cond_embedding_dim = self.cond_embedding_dim if self.args.condition_dp_on_speaker else 0 + + if self.args.emotion_just_encoder and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings): + dp_cond_embedding_dim += self.args.emotion_embedding_dim + + if self.args.use_prosody_encoder: + dp_cond_embedding_dim += self.args.prosody_embedding_dim + if self.args.use_sdp: self.duration_predictor = StochasticDurationPredictor( - self.args.hidden_channels + self.args.emotion_embedding_dim, + self.args.hidden_channels + self.args.emotion_embedding_dim + self.args.prosody_embedding_dim, 192, 3, self.args.dropout_p_duration_predictor, 4, - cond_channels=self.cond_embedding_dim if self.args.condition_dp_on_speaker else 0, + cond_channels=dp_cond_embedding_dim, language_emb_dim=self.embedded_language_dim, ) else: self.duration_predictor = DurationPredictor( - self.args.hidden_channels + self.args.emotion_embedding_dim, + self.args.hidden_channels + self.args.emotion_embedding_dim + self.args.prosody_embedding_dim, 256, 3, self.args.dropout_p_duration_predictor, - cond_channels=self.cond_embedding_dim, + cond_channels=dp_cond_embedding_dim, language_emb_dim=self.embedded_language_dim, ) + if self.args.use_prosody_encoder: + self.prosody_encoder = GST( + num_mel=self.args.hidden_channels, + num_heads=1, + num_style_tokens=5, + gst_embedding_dim=self.args.prosody_embedding_dim, + ) + self.waveform_decoder = HifiganGenerator( self.args.hidden_channels, 1, @@ -840,10 +864,12 @@ class Vits(BaseTTS): if self.num_emotions > 0: print(" > initialization of emotion-embedding layers.") self.emb_emotion = nn.Embedding(self.num_emotions, self.args.emotion_embedding_dim) - self.cond_embedding_dim += self.args.emotion_embedding_dim + if not self.args.emotion_just_encoder: + self.cond_embedding_dim += self.args.emotion_embedding_dim if self.args.use_external_emotions_embeddings: - self.cond_embedding_dim += self.args.emotion_embedding_dim + if not self.args.emotion_just_encoder: + self.cond_embedding_dim += self.args.emotion_embedding_dim def get_aux_input(self, aux_input: Dict): sid, g, lid, eid, eg = self._set_cond_input(aux_input) @@ -1039,7 +1065,7 @@ class Vits(BaseTTS): eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1] # concat the emotion embedding and speaker embedding - if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings): + if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.emotion_just_encoder: if g is None: g = eg else: @@ -1050,16 +1076,34 @@ class Vits(BaseTTS): if self.args.use_language_embedding and lid is not None: lang_emb = self.emb_l(lid).unsqueeze(-1) - x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg) - # posterior encoder z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g) + # prosody embedding + pros_emb = None + if self.args.use_prosody_encoder: + pros_emb = self.prosody_encoder(z).transpose(1, 2) + + x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb) + # flow layers z_p = self.flow(z, y_mask, g=g) # duration predictor - outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb) + g_dp = g + if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder: + if g_dp is None: + g_dp = eg + else: + g_dp = torch.cat([g_dp, eg], dim=1) # [b, h1+h2, 1] + + if self.args.use_prosody_encoder: + if g_dp is None: + g_dp = pros_emb + else: + g_dp = torch.cat([g_dp, pros_emb], dim=1) # [b, h1+h2, 1] + + outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb) # expand prior m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p]) @@ -1169,7 +1213,7 @@ class Vits(BaseTTS): eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1] # concat the emotion embedding and speaker embedding - if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings): + if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.emotion_just_encoder: if g is None: g = eg else: @@ -1182,18 +1226,27 @@ class Vits(BaseTTS): x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg) + # duration predictor + if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder: + if g is None: + g_dp = eg + else: + g_dp = torch.cat([g, eg], dim=1) # [b, h1+h2, 1] + else: + g_dp = g + if self.args.use_sdp: logw = self.duration_predictor( x, x_mask, - g=g if self.args.condition_dp_on_speaker else None, + g=g_dp if self.args.condition_dp_on_speaker else None, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb, ) else: logw = self.duration_predictor( - x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb + x, x_mask, g=g_dp if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb ) w = torch.exp(logw) * x_mask * self.length_scale diff --git a/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py b/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py index 69b3ccd5..f200c806 100644 --- a/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py +++ b/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py @@ -43,6 +43,7 @@ config.model_args.d_vector_dim = 256 # emotion config.model_args.use_external_emotions_embeddings = False config.model_args.use_emotion_embedding = True +config.model_args.emotion_just_encoder = False config.model_args.emotion_embedding_dim = 256 config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json" diff --git a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py new file mode 100644 index 00000000..7a1cd6ef --- /dev/null +++ b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py @@ -0,0 +1,81 @@ +import glob +import os +import shutil + +from trainer import get_last_checkpoint + +from tests import get_device_id, get_tests_output_path, run_cli +from TTS.tts.configs.vits_config import VitsConfig + +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") +output_path = os.path.join(get_tests_output_path(), "train_outputs") + + +config = VitsConfig( + batch_size=2, + eval_batch_size=2, + num_loader_workers=0, + num_eval_loader_workers=0, + text_cleaner="english_cleaners", + use_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", + run_eval=True, + test_delay_epochs=-1, + epochs=1, + print_step=1, + print_eval=True, + test_sentences=[ + ["Be a voice, not an echo.", "ljspeech-1", None, None, "ljspeech-1"], + ], +) +# set audio config +config.audio.do_trim_silence = True +config.audio.trim_db = 60 + +# active multispeaker d-vec mode +config.model_args.use_speaker_embedding = True +config.model_args.use_d_vector_file = False +config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json" +config.model_args.speaker_embedding_channels = 128 +config.model_args.d_vector_dim = 256 + +# prosody embedding +config.model_args.use_prosody_encoder = True +config.model_args.prosody_embedding_dim = 256 + +config.save_json(config_path) + +# train the model for one epoch +command_train = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.datasets.0.name ljspeech_test " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt " + "--coqpit.test_delay_epochs 0" +) +run_cli(command_train) + +# Find latest folder +continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) + +# Inference using TTS API +continue_config_path = os.path.join(continue_path, "config.json") +continue_restore_path, _ = get_last_checkpoint(continue_path) +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") +speaker_id = "ljspeech-1" +emotion_id = "ljspeech-3" +continue_speakers_path = os.path.join(continue_path, "speakers.json") +continue_emotion_path = os.path.join(continue_path, "speakers.json") + + +inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --emotion_idx {emotion_id} --speakers_file_path {continue_speakers_path} --emotions_file_path {continue_emotion_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" +run_cli(inference_command) + +# restore the model and continue training for one more epoch +command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " +run_cli(command_train) +shutil.rmtree(continue_path)