From 6f95522edf5d066f4237467508ea3f313d68940e Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Tue, 15 Mar 2022 01:16:48 +0000
Subject: [PATCH] Add Emotion Support for the VITS model

---
 TTS/bin/synthesize.py                         |  10 ++
 TTS/tts/models/base_tts.py                    |  22 +++
 TTS/tts/models/vits.py                        | 157 +++++++++++++++---
 TTS/tts/utils/emotions.py                     |  41 ++---
 TTS/tts/utils/synthesis.py                    |  14 ++
 TTS/utils/synthesizer.py                      |  30 +++-
 ...est_vits_speaker_emb_with_emotion_train.py |  83 +++++++++
 7 files changed, 311 insertions(+), 46 deletions(-)
 create mode 100644 tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py

diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py
index 4e93535a..900f5df7 100755
--- a/TTS/bin/synthesize.py
+++ b/TTS/bin/synthesize.py
@@ -152,6 +152,7 @@ If you don't specify any models, then it uses LJSpeech based English model.

     # args for multi-speaker synthesis
     parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
+    parser.add_argument("--emotions_file_path", type=str, help="JSON file for emotion model.", default=None)
     parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None)
     parser.add_argument(
         "--speaker_idx",
@@ -165,6 +166,12 @@ If you don't specify any models, then it uses LJSpeech based English model.
         help="Target language ID for a multi-lingual TTS model.",
         default=None,
     )
+    parser.add_argument(
+        "--emotion_idx",
+        type=str,
+        help="Target emotion ID.",
+        default=None,
+    )
     parser.add_argument(
         "--speaker_wav",
         nargs="+",
@@ -254,6 +261,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
     model_path = args.model_path
     config_path = args.config_path
     speakers_file_path = args.speakers_file_path
+    emotions_file_path = args.emotions_file_path
     language_ids_file_path = args.language_ids_file_path

     if args.vocoder_path is not None:
@@ -269,6 +277,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
         model_path,
         config_path,
         speakers_file_path,
+        emotions_file_path,
         language_ids_file_path,
         vocoder_path,
         vocoder_config_path,
@@ -315,6 +324,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
         style_wav=args.capacitron_style_wav,
         style_text=args.capacitron_style_text,
         reference_speaker_name=args.reference_speaker_idx,
+        emotion_name=args.emotion_idx,
     )

     # save the results
diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index c71872d3..e1158b42 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -428,3 +428,25 @@ class BaseTTS(BaseTrainerModel):
             trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
             print(f" > `language_ids.json` is saved to {output_path}.")
             print(" > `language_ids_file` is updated in the config.json.")
+
+        if hasattr(self, "emotion_manager") and self.emotion_manager is not None:
+            output_path = os.path.join(trainer.output_path, "emotions.json")
+
+            if hasattr(trainer.config, "model_args"):
+                if trainer.config.model_args.use_emotion_embedding and not trainer.config.model_args.external_emotions_embs_file:
+                    self.emotion_manager.save_ids_to_file(output_path)
+                    trainer.config.model_args.emotions_ids_file = output_path
+                else:
+                    self.emotion_manager.save_embeddings_to_file(output_path)
+                    trainer.config.model_args.external_emotions_embs_file = output_path
+            else:
+                if trainer.config.use_emotion_embedding and not trainer.config.external_emotions_embs_file:
+                    self.emotion_manager.save_ids_to_file(output_path)
+                    trainer.config.emotions_ids_file = output_path
+                else:
+                    self.emotion_manager.save_embeddings_to_file(output_path)
+                    trainer.config.external_emotions_embs_file = output_path
+
+            trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
+            print(f" > `emotions.json` is saved to {output_path}.")
+            print(" > `emotions_ids_file` or `external_emotions_embs_file` is updated in the config.json.")
\ No newline at end of file
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index a6b1c743..f27f1f6d 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -25,6 +25,7 @@ from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
 from TTS.tts.utils.languages import LanguageManager
 from TTS.tts.utils.speakers import SpeakerManager
+from TTS.tts.utils.emotions import EmotionManager
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
@@ -300,7 +301,7 @@ class VitsDataset(TTSDataset):
             "waveform_lens": wav_lens,  # (B)
             "waveform_rel_lens": wav_rel_lens,
             "speaker_names": batch["speaker_name"],
-            "language_names": batch["language_name"],
+            "language_names": batch["language_name"],
             "audio_files": batch["wav_file"],
             "raw_text": batch["raw_text"],
         }
@@ -529,6 +530,15 @@ class VitsArgs(Coqpit):
     speaker_embedding_channels: int = 256
     use_d_vector_file: bool = False
     d_vector_dim: int = 0
+
+    # use emotion embeddings
+    use_emotion_embedding: bool = False
+    use_external_emotions_embeddings: bool = False
+    emotions_ids_file: str = None
+    external_emotions_embs_file: str = None
+    emotion_embedding_dim: int = 0
+    num_emotions: int = 0
+
     detach_dp_input: bool = True
     use_language_embedding: bool = False
     embedded_language_dim: int = 4
@@ -584,13 +594,14 @@ class Vits(BaseTTS):
         tokenizer: "TTSTokenizer" = None,
         speaker_manager: SpeakerManager = None,
         language_manager: LanguageManager = None,
+        emotion_manager: EmotionManager = None,
     ):
-
         super().__init__(config, ap, tokenizer, speaker_manager, language_manager)

         self.init_multispeaker(config)
         self.init_multilingual(config)
         self.init_upsampling()
+        self.init_emotion(config, emotion_manager)

         self.length_scale = self.args.length_scale
         self.noise_scale = self.args.noise_scale
@@ -619,7 +630,7 @@ class Vits(BaseTTS):
             kernel_size=self.args.kernel_size_posterior_encoder,
             dilation_rate=self.args.dilation_rate_posterior_encoder,
             num_layers=self.args.num_layers_posterior_encoder,
-            cond_channels=self.embedded_speaker_dim,
+            cond_channels=self.cond_embedding_dim,
         )

         self.flow = ResidualCouplingBlocks(
@@ -628,7 +639,7 @@
             kernel_size=self.args.kernel_size_flow,
             dilation_rate=self.args.dilation_rate_flow,
             num_layers=self.args.num_layers_flow,
-            cond_channels=self.embedded_speaker_dim,
+            cond_channels=self.cond_embedding_dim,
         )

         if self.args.use_sdp:
@@ -638,7 +649,7 @@
                 3,
                 self.args.dropout_p_duration_predictor,
                 4,
-                cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
+                cond_channels=self.cond_embedding_dim if self.args.condition_dp_on_speaker else 0,
                 language_emb_dim=self.embedded_language_dim,
             )
         else:
@@ -647,7 +658,7 @@
                 256,
                 3,
                 self.args.dropout_p_duration_predictor,
-                cond_channels=self.embedded_speaker_dim,
+                cond_channels=self.cond_embedding_dim,
                 language_emb_dim=self.embedded_language_dim,
             )

@@ -661,7 +672,7 @@
             self.args.upsample_initial_channel_decoder,
             self.args.upsample_rates_decoder,
             inference_padding=0,
-            cond_channels=self.embedded_speaker_dim,
+            cond_channels=self.cond_embedding_dim,
             conv_pre_weight_norm=False,
             conv_post_weight_norm=False,
             conv_post_bias=False,
@@ -683,7 +694,7 @@
            config (Coqpit): Model configuration.
            data (List, optional): Dataset items to infer number of speakers. Defaults to None.
        """
-        self.embedded_speaker_dim = 0
+        self.cond_embedding_dim = 0
         self.num_speakers = self.args.num_speakers
         self.audio_transform = None

@@ -726,14 +737,14 @@
         # pylint: disable=attribute-defined-outside-init
         if self.num_speakers > 0:
             print(" > initialization of speaker-embedding layers.")
-            self.embedded_speaker_dim = self.args.speaker_embedding_channels
-            self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
+            self.cond_embedding_dim += self.args.speaker_embedding_channels
+            self.emb_g = nn.Embedding(self.num_speakers, self.args.speaker_embedding_channels)

     def _init_d_vector(self):
         # pylint: disable=attribute-defined-outside-init
         if hasattr(self, "emb_g"):
             raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
-        self.embedded_speaker_dim = self.args.d_vector_dim
+        self.cond_embedding_dim += self.args.d_vector_dim

     def init_multilingual(self, config: Coqpit):
         """Initialize multilingual modules of a model.
@@ -785,9 +796,36 @@
             raise RuntimeError(" [!] The weights of Text Encoder was not reinit check it !")
         print(" > Text Encoder was reinit.")

+    def init_emotion(self, config: Coqpit, emotion_manager: EmotionManager):
+        # pylint: disable=attribute-defined-outside-init
+        """Initialize emotion modules of a model. A model can be trained either with an emotion embedding layer
+        or with external `embeddings` computed from an emotion encoder model.
+
+        You must provide an `emotion_manager` at initialization to set up the emotion modules.
+
+        Args:
+            config (Coqpit): Model configuration.
+            emotion_manager (EmotionManager): Emotion Manager.
+ """ + self.emotion_manager = emotion_manager + self.num_emotions = self.args.num_emotions + + if self.emotion_manager: + self.num_emotions = self.emotion_manager.num_emotions + + if self.args.use_emotion_embedding: + if self.num_emotions > 0: + print(" > initialization of emotion-embedding layers.") + self.emb_emotion = nn.Embedding(self.num_emotions, self.args.emotion_embedding_dim) + self.cond_embedding_dim += self.args.emotion_embedding_dim + + if self.args.use_external_emotions_embeddings: + self.cond_embedding_dim += self.args.emotion_embedding_dim + def get_aux_input(self, aux_input: Dict): - sid, g, lid = self._set_cond_input(aux_input) - return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid} + sid, g, lid, eid, eg = self._set_cond_input(aux_input) + return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid, + "emotion_embeddings": eg, "emotion_ids": eid} def _freeze_layers(self): if self.args.freeze_encoder: @@ -817,7 +855,7 @@ class Vits(BaseTTS): @staticmethod def _set_cond_input(aux_input: Dict): """Set the speaker conditioning input based on the multi-speaker mode.""" - sid, g, lid = None, None, None + sid, g, lid, eid, eg = None, None, None, None, None if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None: sid = aux_input["speaker_ids"] if sid.ndim == 0: @@ -832,7 +870,17 @@ class Vits(BaseTTS): if lid.ndim == 0: lid = lid.unsqueeze_(0) - return sid, g, lid + if "emotion_ids" in aux_input and aux_input["emotion_ids"] is not None: + eid = aux_input["emotion_ids"] + if eid.ndim == 0: + eid = eid.unsqueeze_(0) + + if "emotion_embeddings" in aux_input and aux_input["emotion_embeddings"] is not None: + eg = F.normalize(aux_input["emotion_embeddings"]).unsqueeze(-1) + if eg.ndim == 2: + eg = eg.unsqueeze_(0) + + return sid, g, lid, eid, eg def _set_speaker_input(self, aux_input: Dict): d_vectors = aux_input.get("d_vectors", None) @@ -906,7 +954,7 @@ class Vits(BaseTTS): y: torch.tensor, y_lengths: torch.tensor, waveform: torch.tensor, - aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}, + aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None, "emotion_embeddings": None, "emotion_ids": None}, ) -> Dict: """Forward pass of the model. 
@@ -946,11 +994,19 @@
            - syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
        """
        outputs = {}
-        sid, g, lid = self._set_cond_input(aux_input)
+        sid, g, lid, eid, eg = self._set_cond_input(aux_input)
        # speaker embedding
        if self.args.use_speaker_embedding and sid is not None:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
+        # emotion embedding
+        if self.args.use_emotion_embedding and eid is not None and eg is None:
+            eg = self.emb_emotion(eid).unsqueeze(-1)  # [b, h, 1]
+
+        # concat the emotion embedding and speaker embedding
+        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
+            g = torch.cat([g, eg], dim=1)  # [b, h1+h2, 1]
+

        # language embedding
        lang_emb = None
        if self.args.use_language_embedding and lid is not None:
@@ -1028,7 +1084,7 @@

    @torch.no_grad()
    def inference(
-        self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
+        self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "emotion_embeddings": None, "emotion_ids": None}
    ):  # pylint: disable=dangerous-default-value
        """
        Note:
@@ -1048,13 +1104,21 @@
            - m_p: :math:`[B, C, T_dec]`
            - logs_p: :math:`[B, C, T_dec]`
        """
-        sid, g, lid = self._set_cond_input(aux_input)
+        sid, g, lid, eid, eg = self._set_cond_input(aux_input)
        x_lengths = self._set_x_lengths(x, aux_input)

        # speaker embedding
        if self.args.use_speaker_embedding and sid is not None:
            g = self.emb_g(sid).unsqueeze(-1)
+        # emotion embedding
+        if self.args.use_emotion_embedding and eid is not None and eg is None:
+            eg = self.emb_emotion(eid).unsqueeze(-1)  # [b, h, 1]
+
+        # concat the emotion embedding and speaker embedding
+        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
+            g = torch.cat([g, eg], dim=1)  # [b, h1+h2, 1]
+

        # language embedding
        lang_emb = None
        if self.args.use_language_embedding and lid is not None:
@@ -1187,6 +1251,8 @@
        d_vectors = batch["d_vectors"]
        speaker_ids = batch["speaker_ids"]
        language_ids = batch["language_ids"]
+        emotion_embeddings = batch["emotion_embeddings"]
+        emotion_ids = batch["emotion_ids"]
        waveform = batch["waveform"]

        # generator pass
@@ -1196,7 +1262,8 @@
            spec,
            spec_lens,
            waveform,
-            aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
+            aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids,
+                       "emotion_embeddings": emotion_embeddings, "emotion_ids": emotion_ids},
        )

        # cache tensors for the generator pass
@@ -1322,7 +1389,7 @@
        config = self.config

        # extract speaker and language info
-        text, speaker_name, style_wav, language_name = None, None, None, None
+        text, speaker_name, style_wav, language_name, emotion_name = None, None, None, None, None

        if isinstance(sentence_info, list):
            if len(sentence_info) == 1:
@@ -1333,11 +1400,13 @@
                text, speaker_name, style_wav = sentence_info
            elif len(sentence_info) == 4:
                text, speaker_name, style_wav, language_name = sentence_info
+            elif len(sentence_info) == 5:
+                text, speaker_name, style_wav, language_name, emotion_name = sentence_info
        else:
            text = sentence_info

        # get speaker id/d_vector
-        speaker_id, d_vector, language_id = None, None, None
+        speaker_id, d_vector, language_id, emotion_id, emotion_embedding = None, None, None, None, None
        if hasattr(self, "speaker_manager"):
            if config.use_d_vector_file:
                if speaker_name is None:
@@ -1354,6 +1423,19 @@
        if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
            language_id = self.language_manager.ids[language_name]

+        # get emotion id/embedding
+        if hasattr(self, "emotion_manager"):
+            if config.use_external_emotions_embeddings:
+                if emotion_name is None:
+                    emotion_embedding = self.emotion_manager.get_random_embeddings()
+                else:
+                    emotion_embedding = self.emotion_manager.get_mean_embedding(emotion_name, num_samples=None, randomize=False)
+            elif config.use_emotion_embedding:
+                if emotion_name is None:
+                    emotion_id = self.emotion_manager.get_random_id()
+                else:
+                    emotion_id = self.emotion_manager.ids[emotion_name]
+
        return {
            "text": text,
            "speaker_id": speaker_id,
@@ -1361,6 +1443,8 @@
            "d_vector": d_vector,
            "language_id": language_id,
            "language_name": language_name,
+            "emotion_embedding": emotion_embedding,
+            "emotion_ids": emotion_id
        }

    @torch.no_grad()
@@ -1387,6 +1471,8 @@
                d_vector=aux_inputs["d_vector"],
                style_wav=aux_inputs["style_wav"],
                language_id=aux_inputs["language_id"],
+                emotion_embedding=aux_inputs["emotion_embedding"],
+                emotion_id=aux_inputs["emotion_ids"],
                use_griffin_lim=True,
                do_trim_silence=False,
            ).values()
@@ -1401,10 +1487,12 @@
            logger.test_figures(steps, outputs["figures"])

    def format_batch(self, batch: Dict) -> Dict:
-        """Compute speaker, langugage IDs and d_vector for the batch if necessary."""
+        """Compute speaker, language IDs, d_vector and emotion embeddings for the batch if necessary."""
        speaker_ids = None
        language_ids = None
        d_vectors = None
+        emotion_embeddings = None
+        emotion_ids = None

        # get numerical speaker ids from speaker names
        if self.speaker_manager is not None and self.speaker_manager.ids and self.args.use_speaker_embedding:
@@ -1421,15 +1509,33 @@
            d_vectors = torch.FloatTensor(d_vectors)

        # get language ids from language names
-        if self.language_manager is not None and self.language_manager.ids and self.args.use_language_embedding:
+        if (
+            self.language_manager is not None
+            and self.language_manager.ids
+            and self.args.use_language_embedding
+        ):
            language_ids = [self.language_manager.ids[ln] for ln in batch["language_names"]]

        if language_ids is not None:
            language_ids = torch.LongTensor(language_ids)

+        # get emotion embedding
+        if self.emotion_manager is not None and self.emotion_manager.embeddings and self.args.use_external_emotions_embeddings:
+            emotion_mapping = self.emotion_manager.embeddings
+            emotion_embeddings = [emotion_mapping[w]["embedding"] for w in batch["audio_files"]]
+            emotion_embeddings = torch.FloatTensor(emotion_embeddings)
+
+        if self.emotion_manager is not None and self.emotion_manager.embeddings and self.args.use_emotion_embedding:
+            emotion_mapping = self.emotion_manager.embeddings
+            emotion_names = [emotion_mapping[w]["name"] for w in batch["audio_files"]]
+            emotion_ids = [self.emotion_manager.ids[en] for en in emotion_names]
+            emotion_ids = torch.LongTensor(emotion_ids)
+
        batch["language_ids"] = language_ids
        batch["d_vectors"] = d_vectors
        batch["speaker_ids"] = speaker_ids
+        batch["emotion_embeddings"] = emotion_embeddings
+        batch["emotion_ids"] = emotion_ids
        return batch

    def format_batch_on_device(self, batch):
@@ -1643,12 +1749,13 @@
        tokenizer, new_config = TTSTokenizer.init_from_config(config)
        speaker_manager = SpeakerManager.init_from_config(config, samples)
        language_manager = LanguageManager.init_from_config(config)
+        emotion_manager = EmotionManager.init_from_config(config)

        if config.model_args.speaker_encoder_model_path:
            speaker_manager.init_encoder(
                config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path
            )
-        return Vits(new_config, ap, tokenizer, speaker_manager, language_manager)
+        return Vits(new_config, ap, tokenizer, speaker_manager, language_manager, emotion_manager=emotion_manager)


##################################
diff --git a/TTS/tts/utils/emotions.py b/TTS/tts/utils/emotions.py
index 71cd71bf..d655ba03 100644
--- a/TTS/tts/utils/emotions.py
+++ b/TTS/tts/utils/emotions.py
@@ -1,10 +1,8 @@
 import json
 import os
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, List

 import fsspec
-import numpy as np
-import torch
 from coqpit import Coqpit

 from TTS.config import get_from_config_or_model_args_with_default
@@ -34,7 +32,7 @@ class EmotionManager(EmbeddingManager):
     computes the embeddings for a given clip or emotion.

     Args:
-        emo_embeddings_file_path (str, optional): Path to the metafile including x vectors. Defaults to "".
+        embeddings_file_path (str, optional): Path to the metafile including x vectors. Defaults to "".
         emotion_id_file_path (str, optional): Path to the metafile that maps emotion names to ids used by
         TTS models. Defaults to "".
         encoder_model_path (str, optional): Path to the emotion encoder model file. Defaults to "".
@@ -50,14 +48,14 @@ class EmotionManager(EmbeddingManager):

     def __init__(
         self,
-        emo_embeddings_file_path: str = "",
+        embeddings_file_path: str = "",
         emotion_id_file_path: str = "",
         encoder_model_path: str = "",
         encoder_config_path: str = "",
         use_cuda: bool = False,
     ):
         super().__init__(
-            external_emotions_ids_file_path=emo_embeddings_file_path,
+            embedding_file_path=embeddings_file_path,
             id_file_path=emotion_id_file_path,
             encoder_model_path=encoder_model_path,
             encoder_config_path=encoder_config_path,
@@ -98,11 +96,15 @@ class EmotionManager(EmbeddingManager):
                emotion_manager = EmotionManager(
                    emotion_id_file_path=get_from_config_or_model_args_with_default(config, "emotions_ids_file", None)
                )
-
-        if get_from_config_or_model_args_with_default(config, "use_external_emotion_embedding", False):
-            if get_from_config_or_model_args_with_default(config, "external_emotions_ids_file", None):
+        elif get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None):
            emotion_manager = EmotionManager(
-                embeddings_file_path=get_from_config_or_model_args_with_default(config, "external_emotions_ids_file", None)
+                embeddings_file_path=get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None)
+            )
+
+        if get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False):
+            if get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None):
+                emotion_manager = EmotionManager(
+                    embeddings_file_path=get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None)
                )

        return emotion_manager
@@ -157,26 +159,26 @@ def get_emotion_manager(c: Coqpit, restore_path: str = None, out_path: str = Non
            if c.use_external_emotions_embeddings:
                # restore emotion manager with the embedding file
                if not os.path.exists(emotions_ids_file):
-                    print("WARNING: emotions.json was not found in restore_path, trying to use CONFIG.external_emotions_ids_file")
-                    if not os.path.exists(c.external_emotions_ids_file):
+                    print("WARNING: emotions.json was not found in restore_path, trying to use CONFIG.external_emotions_embs_file")
+                    if not os.path.exists(c.external_emotions_embs_file):
                        raise RuntimeError(
-                            "You must copy the file emotions.json to restore_path, or set a valid file in CONFIG.external_emotions_ids_file"
+                            "You must copy the file emotions.json to restore_path, or set a valid file in CONFIG.external_emotions_embs_file"
                        )
-                    emotion_manager.load_embeddings_from_file(c.external_emotions_ids_file)
+                    emotion_manager.load_embeddings_from_file(c.external_emotions_embs_file)
                emotion_manager.load_embeddings_from_file(emotions_ids_file)
            elif not c.use_external_emotions_embeddings:
                # restor emotion manager with emotion ID file.
                emotion_manager.load_ids_from_file(emotions_ids_file)
-        elif c.use_external_emotions_embeddings and c.external_emotions_ids_file:
+        elif c.use_external_emotions_embeddings and c.external_emotions_embs_file:
            # new emotion manager with external emotion embeddings.
-            emotion_manager.load_embeddings_from_file(c.external_emotions_ids_file)
+            emotion_manager.load_embeddings_from_file(c.external_emotions_embs_file)
-        elif c.use_external_emotions_embeddings and not c.external_emotions_ids_file:
+        elif c.use_external_emotions_embeddings and not c.external_emotions_embs_file:
            raise "use_external_emotions_embeddings is True, so you need pass a external emotion embedding file."
        elif c.use_emotion_embedding:
            if "emotions_ids_file" in c and c.emotions_ids_file:
                emotion_manager.load_ids_from_file(c.emotions_ids_file)
            else:  # enable get ids from eternal embedding files
-                emotion_manager.load_embeddings_from_file(c.external_emotions_ids_file)
+                emotion_manager.load_embeddings_from_file(c.external_emotions_embs_file)

    if emotion_manager.num_emotions > 0:
        print(
@@ -189,9 +191,8 @@ def get_emotion_manager(c: Coqpit, restore_path: str = None, out_path: str = Non
    if out_path:
        out_file_path = os.path.join(out_path, "emotions.json")
        print(f" > Saving `emotions.json` to {out_file_path}.")
-        if c.use_external_emotions_embeddings and c.external_emotions_ids_file:
+        if c.use_external_emotions_embeddings and c.external_emotions_embs_file:
            emotion_manager.save_embeddings_to_file(out_file_path)
        else:
            emotion_manager.save_ids_to_file(out_file_path)
    return emotion_manager
-
diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py
index a74300dc..ddc63cd8 100644
--- a/TTS/tts/utils/synthesis.py
+++ b/TTS/tts/utils/synthesis.py
@@ -29,6 +29,8 @@ def run_model_torch(
     style_text: str = None,
     d_vector: torch.Tensor = None,
     language_id: torch.Tensor = None,
+    emotion_id: torch.Tensor = None,
+    emotion_embedding: torch.Tensor = None,
 ) -> Dict:
     """Run a torch model for inference. It does not support batch inference.

@@ -56,6 +58,8 @@ def run_model_torch(
             "style_mel": style_mel,
             "style_text": style_text,
             "language_ids": language_id,
+            "emotion_ids": emotion_id,
+            "emotion_embeddings": emotion_embedding,
         },
     )
     return outputs
@@ -122,6 +126,8 @@ def synthesis(
     do_trim_silence=False,
     d_vector=None,
     language_id=None,
+    emotion_id=None,
+    emotion_embedding=None
 ):
     """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output
     features to be passed to the vocoder model.
@@ -190,6 +196,12 @@
     if language_id is not None:
         language_id = id_to_torch(language_id, cuda=use_cuda)

+    if emotion_id is not None:
+        emotion_id = id_to_torch(emotion_id, cuda=use_cuda)
+
+    if emotion_embedding is not None:
+        emotion_embedding = embedding_to_torch(emotion_embedding, cuda=use_cuda)
+
     if not isinstance(style_mel, dict):
         # GST or Capacitron style mel
         style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
@@ -212,6 +224,8 @@
         style_text,
         d_vector=d_vector,
         language_id=language_id,
+        emotion_id=emotion_id,
+        emotion_embedding=emotion_embedding,
     )
     model_outputs = outputs["model_outputs"]
     model_outputs = model_outputs[0].data.cpu().numpy()
diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py
index 9ce528a3..cb4308bb 100644
--- a/TTS/utils/synthesizer.py
+++ b/TTS/utils/synthesizer.py
@@ -22,6 +22,7 @@ class Synthesizer(object):
         tts_checkpoint: str,
         tts_config_path: str,
         tts_speakers_file: str = "",
+        tts_emotions_file: str = "",
         tts_languages_file: str = "",
         vocoder_checkpoint: str = "",
         vocoder_config: str = "",
@@ -52,6 +53,7 @@ class Synthesizer(object):
         self.tts_checkpoint = tts_checkpoint
         self.tts_config_path = tts_config_path
         self.tts_speakers_file = tts_speakers_file
+        self.tts_emotions_file = tts_emotions_file
         self.tts_languages_file = tts_languages_file
         self.vocoder_checkpoint = vocoder_checkpoint
         self.vocoder_config = vocoder_config
@@ -183,6 +185,7 @@ class Synthesizer(object):
         style_text=None,
         reference_wav=None,
         reference_speaker_name=None,
+        emotion_name=None,
     ) -> List[int]:
         """🐸 TTS magic. Run all the models and generate speech.

@@ -240,7 +243,7 @@ class Synthesizer(object):
                     "Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. "
                 )

-        # handle multi-lingaul
+        # handle multi-lingual
         language_id = None
         if self.tts_languages_file or (
             hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
@@ -260,6 +263,29 @@ class Synthesizer(object):
                     "Define path for language_ids.json if it is a multi-lingual model or remove defined language idx. "
                 )

+        # handle emotion
+        emotion_embedding, emotion_id = None, None
+        if self.tts_emotions_file or hasattr(self.tts_model.emotion_manager, "ids"):
+            if emotion_name and isinstance(emotion_name, str):
+                if getattr(self.tts_config, "use_external_emotions_embeddings", False) or getattr(self.tts_config.model_args, "use_external_emotions_embeddings", False):
+                    # get the average emotion embedding from the saved embeddings.
+                    emotion_embedding = self.tts_model.emotion_manager.get_mean_embedding(emotion_name, num_samples=None, randomize=False)
+                    emotion_embedding = np.array(emotion_embedding)[None, :]  # [1 x embedding_dim]
+                else:
+                    # get the emotion idx from the emotion name
+                    emotion_id = self.tts_model.emotion_manager.ids[emotion_name]
+            elif not emotion_name:
+                raise ValueError(
+                    " [!] Looks like you are using an emotion model. "
+                    "You need to define an `emotion_name` to use an emotion model."
+                )
+        else:
+            if emotion_name:
+                raise ValueError(
+                    f" [!] Missing emotions.json file path for selecting the emotion {emotion_name}."
+                    "Define path for emotions.json if it is an emotion model or remove the defined emotion idx. "
+                )
+
         # compute a new d_vector from the given clip.
         if speaker_wav is not None:
             speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)
@@ -280,6 +306,8 @@
                    use_griffin_lim=use_gl,
                    d_vector=speaker_embedding,
                    language_id=language_id,
+                    emotion_embedding=emotion_embedding,
+                    emotion_id=emotion_id,
                )
                waveform = outputs["wav"]
                mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy()
diff --git a/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py b/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py
new file mode 100644
index 00000000..6ce59c6c
--- /dev/null
+++ b/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py
@@ -0,0 +1,83 @@
+import glob
+import os
+import shutil
+
+from trainer import get_last_checkpoint
+
+from tests import get_device_id, get_tests_output_path, run_cli
+from TTS.tts.configs.vits_config import VitsConfig
+
+config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
+output_path = os.path.join(get_tests_output_path(), "train_outputs")
+
+
+config = VitsConfig(
+    batch_size=2,
+    eval_batch_size=2,
+    num_loader_workers=0,
+    num_eval_loader_workers=0,
+    text_cleaner="english_cleaners",
+    use_phonemes=True,
+    phoneme_language="en-us",
+    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
+    run_eval=True,
+    test_delay_epochs=-1,
+    epochs=1,
+    print_step=1,
+    print_eval=True,
+    test_sentences=[
+        ["Be a voice, not an echo.", "ljspeech-1", None, None, "ljspeech-1"],
+    ],
+)
+# set audio config
+config.audio.do_trim_silence = True
+config.audio.trim_db = 60
+
+# activate multi-speaker mode with a speaker embedding layer
+config.model_args.use_speaker_embedding = True
+config.model_args.use_d_vector_file = False
+config.model_args.d_vector_file = None
+config.model_args.d_vector_dim = 256
+
+# emotion
+config.model_args.use_external_emotions_embeddings = True
+config.model_args.use_emotion_embedding = False
+config.model_args.emotion_embedding_dim = 256
+config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
+
+
+config.save_json(config_path)
+
+# train the model for one epoch
+command_train = (
+    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
+    f"--coqpit.output_path {output_path} "
+    "--coqpit.datasets.0.name ljspeech_test "
+    "--coqpit.datasets.0.meta_file_train metadata.csv "
+    "--coqpit.datasets.0.meta_file_val metadata.csv "
+    "--coqpit.datasets.0.path tests/data/ljspeech "
+    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
+    "--coqpit.test_delay_epochs 0"
+)
+run_cli(command_train)
+
+# Find latest folder
+continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
+
+# Inference using TTS API
+continue_config_path = os.path.join(continue_path, "config.json")
+continue_restore_path, _ = get_last_checkpoint(continue_path)
+out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
+speaker_id = "ljspeech-1"
+emotion_id = "ljspeech-1"
+continue_speakers_path = os.path.join(continue_path, "speakers.json")
+continue_emotion_path = os.path.join(continue_path, "speakers.json")
+
+
+inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --emotion_idx {emotion_id} --speakers_file_path {continue_speakers_path} --emotions_file_path {continue_emotion_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
run_cli(inference_command)
+
+# restore the model and continue training for one more epoch
+command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
+run_cli(command_train)
+shutil.rmtree(continue_path)
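
Usage note (reviewer sketch, not part of the patch): the snippet below illustrates how the new `tts_emotions_file` constructor argument and the `emotion_name` argument of `Synthesizer.tts()` introduced above are meant to be used from Python. The checkpoint/config paths and the emotion label are placeholders, and the remaining `Synthesizer` parameters come from the surrounding codebase rather than this diff, so treat it as an illustrative sketch only.

    from TTS.utils.synthesizer import Synthesizer

    # placeholder paths to a trained VITS model with emotion support
    synthesizer = Synthesizer(
        tts_checkpoint="output/checkpoint.pth",
        tts_config_path="output/config.json",
        tts_speakers_file="output/speakers.json",
        tts_emotions_file="output/emotions.json",
        use_cuda=False,
    )
    # speaker_name and emotion_name must exist in speakers.json / emotions.json
    wav = synthesizer.tts("Be a voice, not an echo.", speaker_name="ljspeech-1", emotion_name="Happy")
    synthesizer.save_wav(wav, "output/example_emotion.wav")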