diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py index ff1df767..c36ace93 100644 --- a/TTS/tts/utils/languages.py +++ b/TTS/tts/utils/languages.py @@ -69,10 +69,10 @@ class LanguageManager(BaseIDManager): self.ids = self.parse_language_ids_from_config(c) @staticmethod - def parse_ids_from_data(items: list) -> Any: + def parse_ids_from_data(items: List, parse_key: str) -> Any: raise NotImplementedError - def set_ids_from_data(self, items: List) -> Any: + def set_ids_from_data(self, items: List, parse_key: str) -> Any: raise NotImplementedError def save_ids_to_file(self, file_path: str) -> None: diff --git a/TTS/tts/utils/managers.py b/TTS/tts/utils/managers.py index e19e08a7..66a2824c 100644 --- a/TTS/tts/utils/managers.py +++ b/TTS/tts/utils/managers.py @@ -1,6 +1,6 @@ import json import random -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Tuple, Union import fsspec import numpy as np @@ -34,14 +34,13 @@ class BaseIDManager: with fsspec.open(json_file_path, "w") as f: json.dump(data, f, indent=4) - - def set_ids_from_data(self, items: List) -> None: + def set_ids_from_data(self, items: List, parse_key: str) -> None: """Set IDs from data samples. Args: items (List): Data sampled returned by `load_tts_samples()`. """ - self.ids, _ = self.parse_ids_from_data(items) + self.ids = self.parse_ids_from_data(items, parse_key=parse_key) def load_ids_from_file(self, file_path: str) -> None: """Set IDs from a file. @@ -73,9 +72,18 @@ class BaseIDManager: return None @staticmethod - def parse_ids_from_data(items: list) -> Any: - raise NotImplementedError + def parse_ids_from_data(items: List, parse_key: str) -> Tuple[Dict]: + """Parse IDs from data samples returned by `load_tts_samples()`. + Args: + items (list): Data samples returned by `load_tts_samples()`. + parse_key (str): The key used to parse the data. + Returns: + Tuple[Dict]: speaker IDs. 
+ """ + classes = sorted({item[parse_key] for item in items}) + ids = {name: i for i, name in enumerate(classes)} + return ids class EmbeddingManager(BaseIDManager): """ Base `Embedding` Manager class. Every new `Embedding` manager must inherit this. @@ -273,7 +281,3 @@ class EmbeddingManager(BaseIDManager): if self.use_cuda: feats = feats.cuda() return self.encoder.compute_embedding(feats) - - @staticmethod - def parse_ids_from_data(items: list) -> Any: - raise NotImplementedError diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index f3a2845f..d6577d5d 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -1,6 +1,6 @@ import json import os -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Union import fsspec import numpy as np @@ -68,7 +68,7 @@ class SpeakerManager(EmbeddingManager): ) if data_items: - self.ids, _ = self.parse_ids_from_data(data_items) + self.set_ids_from_data(data_items, parse_key="speaker_name") @property def num_speakers(self): @@ -78,21 +78,6 @@ class SpeakerManager(EmbeddingManager): def speaker_names(self): return list(self.ids.keys()) - @staticmethod - def parse_ids_from_data(items: list) -> Tuple[Dict, int]: - """Parse speaker IDs from data samples retured by `load_tts_samples()`. - - Args: - items (list): Data sampled returned by `load_tts_samples()`. - - Returns: - Tuple[Dict, int]: speaker IDs and number of speakers. 
- """ - speakers = sorted({item["speaker_name"] for item in items}) - speaker_ids = {name: i for i, name in enumerate(speakers)} - num_speakers = len(speaker_ids) - return speaker_ids, num_speakers - def get_speakers(self) -> List: return self.ids @@ -180,7 +165,7 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None, speaker_manager = SpeakerManager() if c.use_speaker_embedding: if data is not None: - speaker_manager.set_ids_from_data(data) + speaker_manager.set_ids_from_data(data, parse_key="speaker_name") if restore_path: speakers_file = _set_file_path(restore_path) # restoring speaker manager from a previous run. diff --git a/recipes/multilingual/vits_tts/train_vits_tts.py b/recipes/multilingual/vits_tts/train_vits_tts.py index c0d22473..0e650ade 100644 --- a/recipes/multilingual/vits_tts/train_vits_tts.py +++ b/recipes/multilingual/vits_tts/train_vits_tts.py @@ -119,7 +119,7 @@ train_samples, eval_samples = load_tts_samples( # init speaker manager for multi-speaker training # it maps speaker-id to speaker-name in the model and data-loader speaker_manager = SpeakerManager() -speaker_manager.set_ids_from_data(train_samples + eval_samples) +speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name") config.model_args.num_speakers = speaker_manager.num_speakers language_manager = LanguageManager(config=config) diff --git a/recipes/vctk/fast_pitch/train_fast_pitch.py b/recipes/vctk/fast_pitch/train_fast_pitch.py index 91f4631e..c39932da 100644 --- a/recipes/vctk/fast_pitch/train_fast_pitch.py +++ b/recipes/vctk/fast_pitch/train_fast_pitch.py @@ -81,7 +81,7 @@ train_samples, eval_samples = load_tts_samples( # init speaker manager for multi-speaker training # it maps speaker-id to speaker-name in the model and data-loader speaker_manager = SpeakerManager() -speaker_manager.set_ids_from_data(train_samples + eval_samples) +speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name") 
config.model_args.num_speakers = speaker_manager.num_speakers # init model diff --git a/recipes/vctk/fast_speech/train_fast_speech.py b/recipes/vctk/fast_speech/train_fast_speech.py index 8df5b7d1..a3249de1 100644 --- a/recipes/vctk/fast_speech/train_fast_speech.py +++ b/recipes/vctk/fast_speech/train_fast_speech.py @@ -79,7 +79,7 @@ train_samples, eval_samples = load_tts_samples( # init speaker manager for multi-speaker training # it maps speaker-id to speaker-name in the model and data-loader speaker_manager = SpeakerManager() -speaker_manager.set_ids_from_data(train_samples + eval_samples) +speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name") config.model_args.num_speakers = speaker_manager.num_speakers # init model diff --git a/recipes/vctk/glow_tts/train_glow_tts.py b/recipes/vctk/glow_tts/train_glow_tts.py index 7b33051a..23c02efc 100644 --- a/recipes/vctk/glow_tts/train_glow_tts.py +++ b/recipes/vctk/glow_tts/train_glow_tts.py @@ -79,7 +79,7 @@ train_samples, eval_samples = load_tts_samples( # init speaker manager for multi-speaker training # it maps speaker-id to speaker-name in the model and data-loader speaker_manager = SpeakerManager() -speaker_manager.set_ids_from_data(train_samples + eval_samples) +speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name") config.num_speakers = speaker_manager.num_speakers # init model diff --git a/recipes/vctk/speedy_speech/train_speedy_speech.py b/recipes/vctk/speedy_speech/train_speedy_speech.py index c8ccf7ed..bcd0105a 100644 --- a/recipes/vctk/speedy_speech/train_speedy_speech.py +++ b/recipes/vctk/speedy_speech/train_speedy_speech.py @@ -79,7 +79,7 @@ train_samples, eval_samples = load_tts_samples( # init speaker manager for multi-speaker training # it maps speaker-id to speaker-name in the model and data-loader speaker_manager = SpeakerManager() -speaker_manager.set_ids_from_data(train_samples + eval_samples) 
+speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name") config.model_args.num_speakers = speaker_manager.num_speakers # init model diff --git a/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py b/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py index b4f48a7c..36e28ed7 100644 --- a/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py +++ b/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py @@ -82,7 +82,7 @@ train_samples, eval_samples = load_tts_samples( # init speaker manager for multi-speaker training # it mainly handles speaker-id to speaker-name for the model and the data-loader speaker_manager = SpeakerManager() -speaker_manager.set_ids_from_data(train_samples + eval_samples) +speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name") # init model model = Tacotron(config, ap, tokenizer, speaker_manager) diff --git a/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py b/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py index 4ea7a9b5..d04d91c0 100644 --- a/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py +++ b/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py @@ -88,7 +88,7 @@ train_samples, eval_samples = load_tts_samples( # init speaker manager for multi-speaker training # it mainly handles speaker-id to speaker-name for the model and the data-loader speaker_manager = SpeakerManager() -speaker_manager.set_ids_from_data(train_samples + eval_samples) +speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name") # init model model = Tacotron2(config, ap, tokenizer, speaker_manager) diff --git a/recipes/vctk/tacotron2/train_tacotron2.py b/recipes/vctk/tacotron2/train_tacotron2.py index 3e8d49af..5a0e157a 100644 --- a/recipes/vctk/tacotron2/train_tacotron2.py +++ b/recipes/vctk/tacotron2/train_tacotron2.py @@ -88,7 +88,7 @@ train_samples, eval_samples = load_tts_samples( # init speaker manager for multi-speaker training # it mainly handles speaker-id to speaker-name for the model 
and the data-loader speaker_manager = SpeakerManager() -speaker_manager.set_ids_from_data(train_samples + eval_samples) +speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name") # init model model = Tacotron2(config, ap, tokenizer, speaker_manager) diff --git a/recipes/vctk/vits/train_vits.py b/recipes/vctk/vits/train_vits.py index 0d375773..88fd7de9 100644 --- a/recipes/vctk/vits/train_vits.py +++ b/recipes/vctk/vits/train_vits.py @@ -89,7 +89,7 @@ train_samples, eval_samples = load_tts_samples( # init speaker manager for multi-speaker training # it maps speaker-id to speaker-name in the model and data-loader speaker_manager = SpeakerManager() -speaker_manager.set_ids_from_data(train_samples + eval_samples) +speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name") config.model_args.num_speakers = speaker_manager.num_speakers # init model