diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 3d996d28..d2992f4e 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -312,6 +312,7 @@ class VitsDataset(TTSDataset): "wav_file": wav_filename, "speaker_name": item["speaker_name"], "language_name": item["language"], + "emotion_name": item["emotion_name"], "pitch": f0, "alignments": alignments, @@ -398,6 +399,7 @@ class VitsDataset(TTSDataset): "pitch": pitch, "speaker_names": batch["speaker_name"], "language_names": batch["language_name"], + "emotion_names": batch["emotion_name"], "audio_files": batch["wav_file"], "raw_text": batch["raw_text"], "alignments": padded_alignments, @@ -1395,7 +1397,6 @@ class Vits(BaseTTS): # speaker embedding if self.args.use_speaker_embedding and sid is not None: g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] - # emotion embedding if self.args.use_emotion_embedding and eid is not None and eg is None: eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1] @@ -2248,6 +2249,13 @@ class Vits(BaseTTS): emotion_ids = [self.emotion_manager.ids[en] for en in emotion_names] emotion_ids = torch.LongTensor(emotion_ids) + if ( + self.emotion_manager is not None + and self.emotion_manager.ids + and self.args.use_emotion_embedding + ): + emotion_ids = torch.LongTensor([self.emotion_manager.ids[en] for en in batch["emotion_names"]]) + batch["language_ids"] = language_ids batch["d_vectors"] = d_vectors batch["speaker_ids"] = speaker_ids diff --git a/TTS/tts/utils/emotions.py b/TTS/tts/utils/emotions.py index 80c01e12..0b88d1ef 100644 --- a/TTS/tts/utils/emotions.py +++ b/TTS/tts/utils/emotions.py @@ -51,6 +51,7 @@ class EmotionManager(EmbeddingManager): def __init__( self, + data_items: List[List[Any]] = None, embeddings_file_path: str = "", emotion_id_file_path: str = "", encoder_model_path: str = "", @@ -65,6 +66,9 @@ class EmotionManager(EmbeddingManager): use_cuda=use_cuda, ) + if data_items: + self.set_ids_from_data(data_items, parse_key="emotion_name") + @property 
def num_emotions(self): return len(self.ids) @@ -75,10 +79,25 @@ class EmotionManager(EmbeddingManager): @staticmethod def parse_ids_from_data(items: List, parse_key: str) -> Any: - raise NotImplementedError + """Parse IDs from data samples returned by `load_tts_samples()`. + + Args: + items (list): Data samples returned by `load_tts_samples()`. + parse_key (str): The key used to parse the data. + Returns: + Dict: emotion IDs. + """ + classes = sorted({item[parse_key] for item in items}) + ids = {name: i for i, name in enumerate(classes)} + return ids def set_ids_from_data(self, items: List, parse_key: str) -> Any: - raise NotImplementedError + """Set IDs from data samples. + + Args: + items (List): Data samples returned by `load_tts_samples()`. + """ + self.ids = self.parse_ids_from_data(items, parse_key=parse_key) def get_emotions(self) -> List: return self.ids