From 060e0f9368eb6237cf330502b9869b4e87de6c12 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Thu, 31 Mar 2022 08:41:16 -0300
Subject: [PATCH] Add EmbeddingManager and BaseIDManager (#1374)

---
 TTS/bin/compute_embeddings.py                 |   6 +-
 TTS/bin/eval_encoder.py                       |  10 +-
 TTS/bin/extract_tts_spectrograms.py           |   4 +-
 TTS/bin/synthesize.py                         |   4 +-
 TTS/bin/train_encoder.py                      |   4 +-
 TTS/encoder/utils/generic_utils.py            |   2 +-
 TTS/server/server.py                          |   2 +-
 TTS/tts/models/base_tts.py                    |  32 +-
 TTS/tts/models/glow_tts.py                    |   2 +-
 TTS/tts/models/vits.py                        |  44 ++-
 TTS/tts/utils/languages.py                    |  45 +--
 TTS/tts/utils/managers.py                     | 285 ++++++++++++++++
 TTS/tts/utils/speakers.py                     | 308 ++----------------
 TTS/utils/synthesizer.py                      |  18 +-
 .../multilingual/vits_tts/train_vits_tts.py   |   2 +-
 recipes/vctk/fast_pitch/train_fast_pitch.py   |   2 +-
 recipes/vctk/fast_speech/train_fast_speech.py |   2 +-
 recipes/vctk/glow_tts/train_glow_tts.py       |   2 +-
 .../vctk/speedy_speech/train_speedy_speech.py |   2 +-
 .../vctk/tacotron-DDC/train_tacotron-DDC.py   |   2 +-
 .../vctk/tacotron2-DDC/train_tacotron2-ddc.py |   2 +-
 recipes/vctk/tacotron2/train_tacotron2.py     |   2 +-
 recipes/vctk/vits/train_vits.py               |   2 +-
 tests/aux_tests/test_speaker_manager.py       |  22 +-
 tests/tts_tests/test_glow_tts.py              |   2 +-
 tests/tts_tests/test_vits.py                  |   6 +-
 tests/zoo_tests/test_models.py                |   2 +-
 27 files changed, 412 insertions(+), 404 deletions(-)
 create mode 100644 TTS/tts/utils/managers.py

diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py
index d7a2c5f6..b62d603a 100644
--- a/TTS/bin/compute_embeddings.py
+++ b/TTS/bin/compute_embeddings.py
@@ -49,7 +49,7 @@ encoder_manager = SpeakerManager(
     use_cuda=args.use_cuda,
 )
 
-class_name_key = encoder_manager.speaker_encoder_config.class_name_key
+class_name_key = encoder_manager.encoder_config.class_name_key
 
 # compute speaker embeddings
 speaker_mapping = {}
@@ -63,10 +63,10 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
     wav_file_name = os.path.basename(wav_file)
     if args.old_file is not None and wav_file_name in encoder_manager.clip_ids:
         # get the embedding from the old file
-        embedd = encoder_manager.get_d_vector_by_clip(wav_file_name)
+        embedd = encoder_manager.get_embedding_by_clip(wav_file_name)
     else:
         # extract the embedding
-        embedd = encoder_manager.compute_d_vector_from_clip(wav_file)
+        embedd = encoder_manager.compute_embedding_from_clip(wav_file)
 
     # create speaker_mapping if target dataset is defined
     speaker_mapping[wav_file_name] = {}
diff --git a/TTS/bin/eval_encoder.py b/TTS/bin/eval_encoder.py
index 089f3645..7f9fdf93 100644
--- a/TTS/bin/eval_encoder.py
+++ b/TTS/bin/eval_encoder.py
@@ -11,8 +11,8 @@ from TTS.tts.utils.speakers import SpeakerManager
 
 
 def compute_encoder_accuracy(dataset_items, encoder_manager):
 
-    class_name_key = encoder_manager.speaker_encoder_config.class_name_key
-    map_classid_to_classname = getattr(encoder_manager.speaker_encoder_config, "map_classid_to_classname", None)
+    class_name_key = encoder_manager.encoder_config.class_name_key
+    map_classid_to_classname = getattr(encoder_manager.encoder_config, "map_classid_to_classname", None)
 
     class_acc_dict = {}
@@ -22,13 +22,13 @@ def compute_encoder_accuracy(dataset_items, encoder_manager):
         wav_file = item["audio_file"]
 
         # extract the embedding
-        embedd = encoder_manager.compute_d_vector_from_clip(wav_file)
-        if encoder_manager.speaker_encoder_criterion is not None and map_classid_to_classname is not None:
+        embedd = encoder_manager.compute_embedding_from_clip(wav_file)
+        if encoder_manager.encoder_criterion is not None and map_classid_to_classname is not None:
             embedding = torch.FloatTensor(embedd).unsqueeze(0)
             if encoder_manager.use_cuda:
                 embedding = embedding.cuda()
 
-            class_id = encoder_manager.speaker_encoder_criterion.softmax.inference(embedding).item()
+            class_id = encoder_manager.encoder_criterion.softmax.inference(embedding).item()
             predicted_label = map_classid_to_classname[str(class_id)]
         else:
             predicted_label = None
diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py
index fa63c46a..a0dd0549 100755
--- a/TTS/bin/extract_tts_spectrograms.py
+++ b/TTS/bin/extract_tts_spectrograms.py
@@ -37,8 +37,8 @@ def setup_loader(ap, r, verbose=False):
         precompute_num_workers=0,
         use_noise_augment=False,
         verbose=verbose,
-        speaker_id_mapping=speaker_manager.speaker_ids if c.use_speaker_embedding else None,
-        d_vector_mapping=speaker_manager.d_vectors if c.use_d_vector_file else None,
+        speaker_id_mapping=speaker_manager.ids if c.use_speaker_embedding else None,
+        d_vector_mapping=speaker_manager.embeddings if c.use_d_vector_file else None,
     )
 
     if c.use_phonemes and c.compute_input_seq_cache:
diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py
index eb166bc8..6247b2a4 100755
--- a/TTS/bin/synthesize.py
+++ b/TTS/bin/synthesize.py
@@ -278,7 +278,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
         print(
             " > Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model."
         )
-        print(synthesizer.tts_model.speaker_manager.speaker_ids)
+        print(synthesizer.tts_model.speaker_manager.ids)
         return
 
     # query langauge ids of a multi-lingual model.
@@ -286,7 +286,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
         print(
             " > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
         )
-        print(synthesizer.tts_model.language_manager.language_id_mapping)
+        print(synthesizer.tts_model.language_manager.ids)
        return
 
     # check the arguments against a multi-speaker model.
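Note for reviewers: every rename in this patch follows one old -> new mapping, applied consistently across the tree:

    speaker_manager.speaker_ids          -> speaker_manager.ids
    speaker_manager.d_vectors            -> speaker_manager.embeddings
    language_manager.language_id_mapping -> language_manager.ids
    d_vector_dim                         -> embedding_dim
    get_d_vector_by_clip                 -> get_embedding_by_clip
    get_d_vectors_by_speaker             -> get_embeddings_by_name
    get_mean_d_vector                    -> get_mean_embedding
    compute_d_vector_from_clip           -> compute_embedding_from_clip
    compute_d_vector                     -> compute_embeddings
    set_speaker_ids_from_data            -> set_ids_from_data(..., parse_key="speaker_name")
    set_speaker_ids_from_file / set_d_vectors_from_file -> load_ids_from_file / load_embeddings_from_file
    save_speaker_ids_to_file / save_d_vectors_to_file   -> save_ids_to_file / save_embeddings_to_file
    init_speaker_encoder / speaker_encoder              -> init_encoder / encoder
    setup_speaker_encoder_model                         -> setup_encoder_model

SpeakerManager now inherits the shared behavior from EmbeddingManager, and LanguageManager from BaseIDManager; both base classes are added in TTS/tts/utils/managers.py below.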
diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py
index b8d38bac..d28f188e 100644
--- a/TTS/bin/train_encoder.py
+++ b/TTS/bin/train_encoder.py
@@ -12,7 +12,7 @@ from trainer.torch import NoamLR
 from trainer.trainer_utils import get_optimizer
 
 from TTS.encoder.dataset import EncoderDataset
-from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_speaker_encoder_model
+from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_encoder_model
 from TTS.encoder.utils.samplers import PerfectBatchSampler
 from TTS.encoder.utils.training import init_training
 from TTS.encoder.utils.visual import plot_embeddings
@@ -258,7 +258,7 @@ def main(args):  # pylint: disable=redefined-outer-name
     global train_classes
 
     ap = AudioProcessor(**c.audio)
-    model = setup_speaker_encoder_model(c)
+    model = setup_encoder_model(c)
 
     optimizer = get_optimizer(c.optimizer, c.optimizer_params, c.lr, model)
diff --git a/TTS/encoder/utils/generic_utils.py b/TTS/encoder/utils/generic_utils.py
index 19c00582..91a896f6 100644
--- a/TTS/encoder/utils/generic_utils.py
+++ b/TTS/encoder/utils/generic_utils.py
@@ -125,7 +125,7 @@ def to_camel(text):
     return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
 
 
-def setup_speaker_encoder_model(config: "Coqpit"):
+def setup_encoder_model(config: "Coqpit"):
     if config.model_params["model_name"].lower() == "lstm":
         model = LSTMSpeakerEncoder(
             config.model_params["input_dim"],
diff --git a/TTS/server/server.py b/TTS/server/server.py
index aef507fd..fd53e76d 100644
--- a/TTS/server/server.py
+++ b/TTS/server/server.py
@@ -143,7 +143,7 @@ def index():
         "index.html",
         show_details=args.show_details,
         use_multi_speaker=use_multi_speaker,
-        speaker_ids=speaker_manager.speaker_ids if speaker_manager is not None else None,
+        speaker_ids=speaker_manager.ids if speaker_manager is not None else None,
         use_gst=use_gst,
     )
diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index 945c031f..652b77dd 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -136,18 +136,18 @@ class BaseTTS(BaseTrainerModel):
         if hasattr(self, "speaker_manager"):
             if config.use_d_vector_file:
                 if speaker_name is None:
-                    d_vector = self.speaker_manager.get_random_d_vector()
+                    d_vector = self.speaker_manager.get_random_embedding()
                 else:
-                    d_vector = self.speaker_manager.get_d_vector_by_speaker(speaker_name)
+                    d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
             elif config.use_speaker_embedding:
                 if speaker_name is None:
-                    speaker_id = self.speaker_manager.get_random_speaker_id()
+                    speaker_id = self.speaker_manager.get_random_id()
                 else:
-                    speaker_id = self.speaker_manager.speaker_ids[speaker_name]
+                    speaker_id = self.speaker_manager.ids[speaker_name]
 
         # get language id
         if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
-            language_id = self.language_manager.language_id_mapping[language_name]
+            language_id = self.language_manager.ids[language_name]
 
         return {
             "text": text,
@@ -279,23 +279,19 @@ class BaseTTS(BaseTrainerModel):
         # setup multi-speaker attributes
         if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
             if hasattr(config, "model_args"):
-                speaker_id_mapping = (
-                    self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
-                )
-                d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None
+                speaker_id_mapping = self.speaker_manager.ids if config.model_args.use_speaker_embedding else None
+                d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None
                 config.use_d_vector_file = config.model_args.use_d_vector_file
             else:
-                speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
-                d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
+                speaker_id_mapping = self.speaker_manager.ids if config.use_speaker_embedding else None
+                d_vector_mapping = self.speaker_manager.embeddings if config.use_d_vector_file else None
         else:
             speaker_id_mapping = None
             d_vector_mapping = None
 
         # setup multi-lingual attributes
         if hasattr(self, "language_manager") and self.language_manager is not None:
-            language_id_mapping = (
-                self.language_manager.language_id_mapping if self.args.use_language_embedding else None
-            )
+            language_id_mapping = self.language_manager.ids if self.args.use_language_embedding else None
         else:
             language_id_mapping = None
@@ -352,13 +348,13 @@ class BaseTTS(BaseTrainerModel):
 
         d_vector = None
         if self.config.use_d_vector_file:
-            d_vector = [self.speaker_manager.d_vectors[name]["embedding"] for name in self.speaker_manager.d_vectors]
+            d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings]
             d_vector = (random.sample(sorted(d_vector), 1),)
 
         aux_inputs = {
             "speaker_id": None
             if not self.config.use_speaker_embedding
-            else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1),
+            else random.sample(sorted(self.speaker_manager.ids.values()), 1),
             "d_vector": d_vector,
             "style_wav": None,  # TODO: handle GST style input
         }
@@ -405,7 +401,7 @@ class BaseTTS(BaseTrainerModel):
         """Save the speaker.json and language_ids.json at the beginning of the training. Also update both paths."""
         if self.speaker_manager is not None:
             output_path = os.path.join(trainer.output_path, "speakers.json")
-            self.speaker_manager.save_speaker_ids_to_file(output_path)
+            self.speaker_manager.save_ids_to_file(output_path)
             trainer.config.speakers_file = output_path
             # some models don't have `model_args` set
             if hasattr(trainer.config, "model_args"):
@@ -416,7 +412,7 @@ class BaseTTS(BaseTrainerModel):
 
         if hasattr(self, "language_manager") and self.language_manager is not None:
             output_path = os.path.join(trainer.output_path, "language_ids.json")
-            self.language_manager.save_language_ids_to_file(output_path)
+            self.language_manager.save_ids_to_file(output_path)
             trainer.config.language_ids_file = output_path
             if hasattr(trainer.config, "model_args"):
                 trainer.config.model_args.language_ids_file = output_path
diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py
index fea570a6..7c0f95e1 100644
--- a/TTS/tts/models/glow_tts.py
+++ b/TTS/tts/models/glow_tts.py
@@ -124,7 +124,7 @@ class GlowTTS(BaseTTS):
             )
         if self.speaker_manager is not None:
             assert (
-                config.d_vector_dim == self.speaker_manager.d_vector_dim
+                config.d_vector_dim == self.speaker_manager.embedding_dim
             ), " [!] d-vector dimension mismatch b/w config and speaker manager."
         # init speaker embedding layer
         if config.use_speaker_embedding and not config.use_d_vector_file:
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 87d559fc..943b9eae 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -652,28 +652,28 @@ class Vits(BaseTTS):
 
         # TODO: make this a function
         if self.args.use_speaker_encoder_as_loss:
-            if self.speaker_manager.speaker_encoder is None and (
+            if self.speaker_manager.encoder is None and (
                 not self.args.speaker_encoder_model_path or not self.args.speaker_encoder_config_path
             ):
                 raise RuntimeError(
                     " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
                 )
 
-            self.speaker_manager.speaker_encoder.eval()
+            self.speaker_manager.encoder.eval()
             print(" > External Speaker Encoder Loaded !!")
 
             if (
-                hasattr(self.speaker_manager.speaker_encoder, "audio_config")
-                and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
+                hasattr(self.speaker_manager.encoder, "audio_config")
+                and self.config.audio["sample_rate"] != self.speaker_manager.encoder.audio_config["sample_rate"]
             ):
                 self.audio_transform = torchaudio.transforms.Resample(
                     orig_freq=self.audio_config["sample_rate"],
-                    new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
+                    new_freq=self.speaker_manager.encoder.audio_config["sample_rate"],
                 )
             # pylint: disable=W0101,W0105
             self.audio_transform = torchaudio.transforms.Resample(
                 orig_freq=self.config.audio.sample_rate,
-                new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
+                new_freq=self.speaker_manager.encoder.audio_config["sample_rate"],
             )
@@ -887,7 +887,7 @@ class Vits(BaseTTS):
             pad_short=True,
         )
 
-        if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
+        if self.args.use_speaker_encoder_as_loss and self.speaker_manager.encoder is not None:
             # concate generated and GT waveforms
             wavs_batch = torch.cat((wav_seg, o), dim=0)
 
@@ -896,7 +896,7 @@ class Vits(BaseTTS):
             if self.audio_transform is not None:
                 wavs_batch = self.audio_transform(wavs_batch)
 
-            pred_embs = self.speaker_manager.speaker_encoder.forward(wavs_batch, l2_norm=True)
+            pred_embs = self.speaker_manager.encoder.forward(wavs_batch, l2_norm=True)
 
             # split generated and GT speaker embeddings
             gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
@@ -1223,18 +1223,18 @@ class Vits(BaseTTS):
         if hasattr(self, "speaker_manager"):
             if config.use_d_vector_file:
                 if speaker_name is None:
-                    d_vector = self.speaker_manager.get_random_d_vector()
+                    d_vector = self.speaker_manager.get_random_embedding()
                 else:
-                    d_vector = self.speaker_manager.get_mean_d_vector(speaker_name, num_samples=None, randomize=False)
+                    d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)
             elif config.use_speaker_embedding:
                 if speaker_name is None:
-                    speaker_id = self.speaker_manager.get_random_speaker_id()
+                    speaker_id = self.speaker_manager.get_random_id()
                 else:
-                    speaker_id = self.speaker_manager.speaker_ids[speaker_name]
+                    speaker_id = self.speaker_manager.ids[speaker_name]
 
         # get language id
         if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
-            language_id = self.language_manager.language_id_mapping[language_name]
+            language_id = self.language_manager.ids[language_name]
 
         return {
             "text": text,
@@ -1289,26 +1289,22 @@ class Vits(BaseTTS):
             d_vectors = None
 
         # get numerical speaker ids from speaker names
-        if self.speaker_manager is not None and self.speaker_manager.speaker_ids and self.args.use_speaker_embedding:
-            speaker_ids = [self.speaker_manager.speaker_ids[sn] for sn in batch["speaker_names"]]
+        if self.speaker_manager is not None and self.speaker_manager.ids and self.args.use_speaker_embedding:
+            speaker_ids = [self.speaker_manager.ids[sn] for sn in batch["speaker_names"]]
 
         if speaker_ids is not None:
             speaker_ids = torch.LongTensor(speaker_ids)
             batch["speaker_ids"] = speaker_ids
 
         # get d_vectors from audio file names
-        if self.speaker_manager is not None and self.speaker_manager.d_vectors and self.args.use_d_vector_file:
-            d_vector_mapping = self.speaker_manager.d_vectors
+        if self.speaker_manager is not None and self.speaker_manager.embeddings and self.args.use_d_vector_file:
+            d_vector_mapping = self.speaker_manager.embeddings
             d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_files"]]
             d_vectors = torch.FloatTensor(d_vectors)
 
         # get language ids from language names
-        if (
-            self.language_manager is not None
-            and self.language_manager.language_id_mapping
-            and self.args.use_language_embedding
-        ):
-            language_ids = [self.language_manager.language_id_mapping[ln] for ln in batch["language_names"]]
+        if self.language_manager is not None and self.language_manager.ids and self.args.use_language_embedding:
+            language_ids = [self.language_manager.ids[ln] for ln in batch["language_names"]]
 
         if language_ids is not None:
             language_ids = torch.LongTensor(language_ids)
@@ -1490,7 +1486,7 @@ class Vits(BaseTTS):
             language_manager = LanguageManager.init_from_config(config)
 
         if config.model_args.speaker_encoder_model_path:
-            speaker_manager.init_speaker_encoder(
+            speaker_manager.init_encoder(
                 config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path
             )
         return Vits(new_config, ap, tokenizer, speaker_manager, language_manager)
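The inference-side selection logic that base_tts.py and vits.py now share is easier to read outside diff context. A minimal sketch, under the assumption that `manager` is a loaded SpeakerManager and `config` a multi-speaker Coqpit config (the function name is hypothetical; the manager methods are the ones defined in managers.py below):

    def pick_speaker_inputs(manager, config, speaker_name=None):
        # mirrors the get_aux_input_from_test_sentences hunks above
        speaker_id, d_vector = None, None
        if config.use_d_vector_file:
            if speaker_name is None:
                d_vector = manager.get_random_embedding()
            else:
                d_vector = manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)
        elif config.use_speaker_embedding:
            if speaker_name is None:
                speaker_id = manager.get_random_id()
            else:
                speaker_id = manager.ids[speaker_name]
        return speaker_id, d_vector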
@@ -25,37 +25,23 @@ class LanguageManager:
         >>> language_id_mapper = manager.language_ids
     """
 
-    language_id_mapping: Dict = {}
-
     def __init__(
         self,
         language_ids_file_path: str = "",
         config: Coqpit = None,
     ):
-        self.language_id_mapping = {}
-        if language_ids_file_path:
-            self.set_language_ids_from_file(language_ids_file_path)
+        super().__init__(id_file_path=language_ids_file_path)
 
         if config:
             self.set_language_ids_from_config(config)
 
-    @staticmethod
-    def _load_json(json_file_path: str) -> Dict:
-        with fsspec.open(json_file_path, "r") as f:
-            return json.load(f)
-
-    @staticmethod
-    def _save_json(json_file_path: str, data: dict) -> None:
-        with fsspec.open(json_file_path, "w") as f:
-            json.dump(data, f, indent=4)
-
     @property
     def num_languages(self) -> int:
-        return len(list(self.language_id_mapping.keys()))
+        return len(list(self.ids.keys()))
 
     @property
     def language_names(self) -> List:
-        return list(self.language_id_mapping.keys())
+        return list(self.ids.keys())
 
     @staticmethod
     def parse_language_ids_from_config(c: Coqpit) -> Dict:
@@ -79,25 +65,24 @@ class LanguageManager:
         """Set language IDs from config samples.
 
         Args:
-            items (List): Data sampled returned by `load_meta_data()`.
+            c (Coqpit): Config.
         """
-        self.language_id_mapping = self.parse_language_ids_from_config(c)
+        self.ids = self.parse_language_ids_from_config(c)
 
-    def set_language_ids_from_file(self, file_path: str) -> None:
-        """Load language ids from a json file.
+    @staticmethod
+    def parse_ids_from_data(items: List, parse_key: str) -> Any:
+        raise NotImplementedError
 
-        Args:
-            file_path (str): Path to the target json file.
-        """
-        self.language_id_mapping = self._load_json(file_path)
+    def set_ids_from_data(self, items: List, parse_key: str) -> Any:
+        raise NotImplementedError
 
-    def save_language_ids_to_file(self, file_path: str) -> None:
+    def save_ids_to_file(self, file_path: str) -> None:
        """Save language IDs to a json file.
 
         Args:
             file_path (str): Path to the output file.
         """
-        self._save_json(file_path, self.language_id_mapping)
+        self._save_json(file_path, self.ids)
 
     @staticmethod
     def init_from_config(config: Coqpit) -> "LanguageManager":
diff --git a/TTS/tts/utils/managers.py b/TTS/tts/utils/managers.py
new file mode 100644
index 00000000..85ed53cc
--- /dev/null
+++ b/TTS/tts/utils/managers.py
@@ -0,0 +1,285 @@
+import json
+import random
+from typing import Any, Dict, List, Tuple, Union
+
+import fsspec
+import numpy as np
+import torch
+
+from TTS.config import load_config
+from TTS.encoder.utils.generic_utils import setup_encoder_model
+from TTS.utils.audio import AudioProcessor
+
+
+class BaseIDManager:
+    """Base `ID` Manager class. Every new `ID` manager must inherit this.
+    It defines common `ID` manager specific functions.
+    """
+
+    def __init__(self, id_file_path: str = ""):
+        self.ids = {}
+
+        if id_file_path:
+            self.load_ids_from_file(id_file_path)
+
+    @staticmethod
+    def _load_json(json_file_path: str) -> Dict:
+        with fsspec.open(json_file_path, "r") as f:
+            return json.load(f)
+
+    @staticmethod
+    def _save_json(json_file_path: str, data: dict) -> None:
+        with fsspec.open(json_file_path, "w") as f:
+            json.dump(data, f, indent=4)
+
+    def set_ids_from_data(self, items: List, parse_key: str) -> None:
+        """Set IDs from data samples.
+
+        Args:
+            items (List): Data samples returned by `load_tts_samples()`.
+        """
+        self.ids = self.parse_ids_from_data(items, parse_key=parse_key)
+
+    def load_ids_from_file(self, file_path: str) -> None:
+        """Set IDs from a file.
+
+        Args:
+            file_path (str): Path to the file.
+        """
+        self.ids = self._load_json(file_path)
+
+    def save_ids_to_file(self, file_path: str) -> None:
+        """Save IDs to a json file.
+
+        Args:
+            file_path (str): Path to the output file.
+        """
+        self._save_json(file_path, self.ids)
+
+    def get_random_id(self) -> Any:
+        """Get a random ID.
+
+        Returns:
+            Any: Random ID.
+        """
+        if self.ids:
+            return self.ids[random.choices(list(self.ids.keys()))[0]]
+
+        return None
+
+    @staticmethod
+    def parse_ids_from_data(items: List, parse_key: str) -> Tuple[Dict]:
+        """Parse IDs from data samples returned by `load_tts_samples()`.
+
+        Args:
+            items (list): Data samples returned by `load_tts_samples()`.
+            parse_key (str): The key used to parse the data.
+
+        Returns:
+            Tuple[Dict]: speaker IDs.
+        """
+        classes = sorted({item[parse_key] for item in items})
+        ids = {name: i for i, name in enumerate(classes)}
+        return ids
+
+
+class EmbeddingManager(BaseIDManager):
+    """Base `Embedding` Manager class. Every new `Embedding` manager must inherit this.
+    It defines common `Embedding` manager specific functions.
+    """
+
+    def __init__(
+        self,
+        embedding_file_path: str = "",
+        id_file_path: str = "",
+        encoder_model_path: str = "",
+        encoder_config_path: str = "",
+        use_cuda: bool = False,
+    ):
+        super().__init__(id_file_path=id_file_path)
+
+        self.embeddings = {}
+        self.embeddings_by_names = {}
+        self.clip_ids = []
+        self.encoder = None
+        self.encoder_ap = None
+        self.use_cuda = use_cuda
+
+        if embedding_file_path:
+            self.load_embeddings_from_file(embedding_file_path)
+
+        if encoder_model_path and encoder_config_path:
+            self.init_encoder(encoder_model_path, encoder_config_path)
+
+    @property
+    def embedding_dim(self):
+        """Dimensionality of embeddings. If embeddings are not loaded, returns zero."""
+        if self.embeddings:
+            return len(self.embeddings[list(self.embeddings.keys())[0]]["embedding"])
+        return 0
+
+    def save_embeddings_to_file(self, file_path: str) -> None:
+        """Save embeddings to a json file.
+
+        Args:
+            file_path (str): Path to the output file.
+        """
+        self._save_json(file_path, self.embeddings)
+
+    def load_embeddings_from_file(self, file_path: str) -> None:
+        """Load embeddings from a json file.
+
+        Args:
+            file_path (str): Path to the target json file.
+        """
+        self.embeddings = self._load_json(file_path)
+
+        speakers = sorted({x["name"] for x in self.embeddings.values()})
+        self.ids = {name: i for i, name in enumerate(speakers)}
+
+        self.clip_ids = list(set(sorted(clip_name for clip_name in self.embeddings.keys())))
+        # cache embeddings_by_names for fast inference using a bigger speakers.json
+        self.embeddings_by_names = self.get_embeddings_by_names()
+
+    def get_embedding_by_clip(self, clip_idx: str) -> List:
+        """Get embedding by clip ID.
+
+        Args:
+            clip_idx (str): Target clip ID.
+
+        Returns:
+            List: embedding as a list.
+        """
+        return self.embeddings[clip_idx]["embedding"]
+
+    def get_embeddings_by_name(self, idx: str) -> List[List]:
+        """Get all embeddings of a speaker.
+
+        Args:
+            idx (str): Target name.
+
+        Returns:
+            List[List]: all the embeddings of the given speaker.
+        """
+        return self.embeddings_by_names[idx]
+
+    def get_embeddings_by_names(self) -> Dict:
+        """Get all embeddings by names.
+
+        Returns:
+            Dict: all the embeddings of each speaker.
+        """
+        embeddings_by_names = {}
+        for x in self.embeddings.values():
+            if x["name"] not in embeddings_by_names.keys():
+                embeddings_by_names[x["name"]] = [x["embedding"]]
+            else:
+                embeddings_by_names[x["name"]].append(x["embedding"])
+        return embeddings_by_names
+ """ + embeddings_by_names = {} + for x in self.embeddings.values(): + if x["name"] not in embeddings_by_names.keys(): + embeddings_by_names[x["name"]] = [x["embedding"]] + else: + embeddings_by_names[x["name"]].append(x["embedding"]) + return embeddings_by_names + + def get_mean_embedding(self, idx: str, num_samples: int = None, randomize: bool = False) -> np.ndarray: + """Get mean embedding of a idx. + + Args: + idx (str): Target name. + num_samples (int, optional): Number of samples to be averaged. Defaults to None. + randomize (bool, optional): Pick random `num_samples` of embeddings. Defaults to False. + + Returns: + np.ndarray: Mean embedding. + """ + embeddings = self.get_embeddings_by_name(idx) + if num_samples is None: + embeddings = np.stack(embeddings).mean(0) + else: + assert len(embeddings) >= num_samples, f" [!] {idx} has number of samples < {num_samples}" + if randomize: + embeddings = np.stack(random.choices(embeddings, k=num_samples)).mean(0) + else: + embeddings = np.stack(embeddings[:num_samples]).mean(0) + return embeddings + + def get_random_embedding(self) -> Any: + """Get a random embedding. + + Args: + + Returns: + np.ndarray: embedding. + """ + if self.embeddings: + return self.embeddings[random.choices(list(self.embeddings.keys()))[0]]["embedding"] + + return None + + def get_clips(self) -> List: + return sorted(self.embeddings.keys()) + + def init_encoder(self, model_path: str, config_path: str) -> None: + """Initialize a speaker encoder model. + + Args: + model_path (str): Model file path. + config_path (str): Model config file path. + """ + self.encoder_config = load_config(config_path) + self.encoder = setup_encoder_model(self.encoder_config) + self.encoder_criterion = self.encoder.load_checkpoint( + self.encoder_config, model_path, eval=True, use_cuda=self.use_cuda + ) + self.encoder_ap = AudioProcessor(**self.encoder_config.audio) + + def compute_embedding_from_clip(self, wav_file: Union[str, List[str]]) -> list: + """Compute a embedding from a given audio file. + + Args: + wav_file (Union[str, List[str]]): Target file path. + + Returns: + list: Computed embedding. + """ + + def _compute(wav_file: str): + waveform = self.encoder_ap.load_wav(wav_file, sr=self.encoder_ap.sample_rate) + if not self.encoder_config.model_params.get("use_torch_spec", False): + m_input = self.encoder_ap.melspectrogram(waveform) + m_input = torch.from_numpy(m_input) + else: + m_input = torch.from_numpy(waveform) + + if self.use_cuda: + m_input = m_input.cuda() + m_input = m_input.unsqueeze(0) + embedding = self.encoder.compute_embedding(m_input) + return embedding + + if isinstance(wav_file, list): + # compute the mean embedding + embeddings = None + for wf in wav_file: + embedding = _compute(wf) + if embeddings is None: + embeddings = embedding + else: + embeddings += embedding + return (embeddings / len(wav_file))[0].tolist() + embedding = _compute(wav_file) + return embedding[0].tolist() + + def compute_embeddings(self, feats: Union[torch.Tensor, np.ndarray]) -> List: + """Compute embedding from features. + + Args: + feats (Union[torch.Tensor, np.ndarray]): Input features. + + Returns: + List: computed embedding. 
+ """ + if isinstance(feats, np.ndarray): + feats = torch.from_numpy(feats) + if feats.ndim == 2: + feats = feats.unsqueeze(0) + if self.use_cuda: + feats = feats.cuda() + return self.encoder.compute_embedding(feats) diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index 0227412d..284d0179 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -1,19 +1,17 @@ import json import os -import random -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Union import fsspec import numpy as np import torch from coqpit import Coqpit -from TTS.config import get_from_config_or_model_args_with_default, load_config -from TTS.encoder.utils.generic_utils import setup_speaker_encoder_model -from TTS.utils.audio import AudioProcessor +from TTS.config import get_from_config_or_model_args_with_default +from TTS.tts.utils.managers import EmbeddingManager -class SpeakerManager: +class SpeakerManager(EmbeddingManager): """Manage the speakers for multi-speaker 🐸TTS models. Load a datafile and parse the information in a way that can be queried by speaker or clip. @@ -50,7 +48,7 @@ class SpeakerManager: >>> # load a sample audio and compute embedding >>> waveform = ap.load_wav(sample_wav_path) >>> mel = ap.melspectrogram(waveform) - >>> d_vector = manager.compute_d_vector(mel.T) + >>> d_vector = manager.compute_embeddings(mel.T) """ def __init__( @@ -62,279 +60,27 @@ class SpeakerManager: encoder_config_path: str = "", use_cuda: bool = False, ): - - self.d_vectors = {} - self.speaker_ids = {} - self.d_vectors_by_speakers = {} - self.clip_ids = [] - self.speaker_encoder = None - self.speaker_encoder_ap = None - self.use_cuda = use_cuda + super().__init__( + embedding_file_path=d_vectors_file_path, + id_file_path=speaker_id_file_path, + encoder_model_path=encoder_model_path, + encoder_config_path=encoder_config_path, + use_cuda=use_cuda, + ) if data_items: - self.speaker_ids, _ = self.parse_speakers_from_data(data_items) - - if d_vectors_file_path: - self.set_d_vectors_from_file(d_vectors_file_path) - - if speaker_id_file_path: - self.set_speaker_ids_from_file(speaker_id_file_path) - - if encoder_model_path and encoder_config_path: - self.init_speaker_encoder(encoder_model_path, encoder_config_path) - - @staticmethod - def _load_json(json_file_path: str) -> Dict: - with fsspec.open(json_file_path, "r") as f: - return json.load(f) - - @staticmethod - def _save_json(json_file_path: str, data: dict) -> None: - with fsspec.open(json_file_path, "w") as f: - json.dump(data, f, indent=4) + self.set_ids_from_data(data_items, parse_key="speaker_name") @property def num_speakers(self): - return len(self.speaker_ids) + return len(self.ids) @property def speaker_names(self): - return list(self.speaker_ids.keys()) - - @property - def d_vector_dim(self): - """Dimensionality of d_vectors. If d_vectors are not loaded, returns zero.""" - if self.d_vectors: - return len(self.d_vectors[list(self.d_vectors.keys())[0]]["embedding"]) - return 0 - - @staticmethod - def parse_speakers_from_data(items: list) -> Tuple[Dict, int]: - """Parse speaker IDs from data samples retured by `load_tts_samples()`. - - Args: - items (list): Data sampled returned by `load_tts_samples()`. - - Returns: - Tuple[Dict, int]: speaker IDs and number of speakers. 
- """ - speakers = sorted({item["speaker_name"] for item in items}) - speaker_ids = {name: i for i, name in enumerate(speakers)} - num_speakers = len(speaker_ids) - return speaker_ids, num_speakers - - def set_speaker_ids_from_data(self, items: List) -> None: - """Set speaker IDs from data samples. - - Args: - items (List): Data sampled returned by `load_tts_samples()`. - """ - self.speaker_ids, _ = self.parse_speakers_from_data(items) - - def set_speaker_ids_from_file(self, file_path: str) -> None: - """Set speaker IDs from a file. - - Args: - file_path (str): Path to the file. - """ - self.speaker_ids = self._load_json(file_path) - - def save_speaker_ids_to_file(self, file_path: str) -> None: - """Save speaker IDs to a json file. - - Args: - file_path (str): Path to the output file. - """ - self._save_json(file_path, self.speaker_ids) - - def save_d_vectors_to_file(self, file_path: str) -> None: - """Save d_vectors to a json file. - - Args: - file_path (str): Path to the output file. - """ - self._save_json(file_path, self.d_vectors) - - def set_d_vectors_from_file(self, file_path: str) -> None: - """Load d_vectors from a json file. - - Args: - file_path (str): Path to the target json file. - """ - self.d_vectors = self._load_json(file_path) - - speakers = sorted({x["name"] for x in self.d_vectors.values()}) - self.speaker_ids = {name: i for i, name in enumerate(speakers)} - - self.clip_ids = list(set(sorted(clip_name for clip_name in self.d_vectors.keys()))) - # cache d_vectors_by_speakers for fast inference using a bigger speakers.json - self.d_vectors_by_speakers = self.get_d_vectors_by_speakers() - - def get_d_vector_by_clip(self, clip_idx: str) -> List: - """Get d_vector by clip ID. - - Args: - clip_idx (str): Target clip ID. - - Returns: - List: d_vector as a list. - """ - return self.d_vectors[clip_idx]["embedding"] - - def get_d_vectors_by_speaker(self, speaker_idx: str) -> List[List]: - """Get all d_vectors of a speaker. - - Args: - speaker_idx (str): Target speaker ID. - - Returns: - List[List]: all the d_vectors of the given speaker. - """ - return self.d_vectors_by_speakers[speaker_idx] - - def get_d_vectors_by_speakers(self) -> Dict: - """Get all d_vectors by speaker. - - Returns: - Dict: all the d_vectors of each speaker. - """ - d_vectors_by_speakers = {} - for x in self.d_vectors.values(): - if x["name"] not in d_vectors_by_speakers.keys(): - d_vectors_by_speakers[x["name"]] = [x["embedding"]] - else: - d_vectors_by_speakers[x["name"]].append(x["embedding"]) - return d_vectors_by_speakers - - def get_mean_d_vector(self, speaker_idx: str, num_samples: int = None, randomize: bool = False) -> np.ndarray: - """Get mean d_vector of a speaker ID. - - Args: - speaker_idx (str): Target speaker ID. - num_samples (int, optional): Number of samples to be averaged. Defaults to None. - randomize (bool, optional): Pick random `num_samples` of d_vectors. Defaults to False. - - Returns: - np.ndarray: Mean d_vector. - """ - d_vectors = self.get_d_vectors_by_speaker(speaker_idx) - if num_samples is None: - d_vectors = np.stack(d_vectors).mean(0) - else: - assert len(d_vectors) >= num_samples, f" [!] speaker {speaker_idx} has number of samples < {num_samples}" - if randomize: - d_vectors = np.stack(random.choices(d_vectors, k=num_samples)).mean(0) - else: - d_vectors = np.stack(d_vectors[:num_samples]).mean(0) - return d_vectors - - def get_random_speaker_id(self) -> Any: - """Get a random d_vector. - - Args: - - Returns: - np.ndarray: d_vector. 
- """ - if self.speaker_ids: - return self.speaker_ids[random.choices(list(self.speaker_ids.keys()))[0]] - - return None - - def get_random_d_vector(self) -> Any: - """Get a random D ID. - - Args: - - Returns: - np.ndarray: d_vector. - """ - if self.d_vectors: - return self.d_vectors[random.choices(list(self.d_vectors.keys()))[0]]["embedding"] - - return None + return list(self.ids.keys()) def get_speakers(self) -> List: - return self.speaker_ids - - def get_clips(self) -> List: - return sorted(self.d_vectors.keys()) - - def init_speaker_encoder(self, model_path: str, config_path: str) -> None: - """Initialize a speaker encoder model. - - Args: - model_path (str): Model file path. - config_path (str): Model config file path. - """ - self.speaker_encoder_config = load_config(config_path) - self.speaker_encoder = setup_speaker_encoder_model(self.speaker_encoder_config) - self.speaker_encoder_criterion = self.speaker_encoder.load_checkpoint( - self.speaker_encoder_config, model_path, eval=True, use_cuda=self.use_cuda - ) - self.speaker_encoder_ap = AudioProcessor(**self.speaker_encoder_config.audio) - - def compute_d_vector_from_clip(self, wav_file: Union[str, List[str]]) -> list: - """Compute a d_vector from a given audio file. - - Args: - wav_file (Union[str, List[str]]): Target file path. - - Returns: - list: Computed d_vector. - """ - - def _compute(wav_file: str): - waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate) - if not self.speaker_encoder_config.model_params.get("use_torch_spec", False): - m_input = self.speaker_encoder_ap.melspectrogram(waveform) - m_input = torch.from_numpy(m_input) - else: - m_input = torch.from_numpy(waveform) - - if self.use_cuda: - m_input = m_input.cuda() - m_input = m_input.unsqueeze(0) - d_vector = self.speaker_encoder.compute_embedding(m_input) - return d_vector - - if isinstance(wav_file, list): - # compute the mean d_vector - d_vectors = None - for wf in wav_file: - d_vector = _compute(wf) - if d_vectors is None: - d_vectors = d_vector - else: - d_vectors += d_vector - return (d_vectors / len(wav_file))[0].tolist() - d_vector = _compute(wav_file) - return d_vector[0].tolist() - - def compute_d_vector(self, feats: Union[torch.Tensor, np.ndarray]) -> List: - """Compute d_vector from features. - - Args: - feats (Union[torch.Tensor, np.ndarray]): Input features. - - Returns: - List: computed d_vector. - """ - if isinstance(feats, np.ndarray): - feats = torch.from_numpy(feats) - if feats.ndim == 2: - feats = feats.unsqueeze(0) - if self.use_cuda: - feats = feats.cuda() - return self.speaker_encoder.compute_embedding(feats) - - def run_umap(self): - # TODO: implement speaker encoder - raise NotImplementedError - - def plot_embeddings(self): - # TODO: implement speaker encoder - raise NotImplementedError + return self.ids @staticmethod def init_from_config(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "SpeakerManager": @@ -420,7 +166,7 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None, speaker_manager = SpeakerManager() if c.use_speaker_embedding: if data is not None: - speaker_manager.set_speaker_ids_from_data(data) + speaker_manager.set_ids_from_data(data, parse_key="speaker_name") if restore_path: speakers_file = _set_file_path(restore_path) # restoring speaker manager from a previous run. 
@@ -432,27 +178,27 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None, raise RuntimeError( "You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.d_vector_file" ) - speaker_manager.load_d_vectors_file(c.d_vector_file) - speaker_manager.set_d_vectors_from_file(speakers_file) + speaker_manager.load_embeddings_from_file(c.d_vector_file) + speaker_manager.load_embeddings_from_file(speakers_file) elif not c.use_d_vector_file: # restor speaker manager with speaker ID file. - speaker_ids_from_data = speaker_manager.speaker_ids - speaker_manager.set_speaker_ids_from_file(speakers_file) + speaker_ids_from_data = speaker_manager.ids + speaker_manager.load_ids_from_file(speakers_file) assert all( - speaker in speaker_manager.speaker_ids for speaker in speaker_ids_from_data + speaker in speaker_manager.ids for speaker in speaker_ids_from_data ), " [!] You cannot introduce new speakers to a pre-trained model." elif c.use_d_vector_file and c.d_vector_file: # new speaker manager with external speaker embeddings. - speaker_manager.set_d_vectors_from_file(c.d_vector_file) + speaker_manager.load_embeddings_from_file(c.d_vector_file) elif c.use_d_vector_file and not c.d_vector_file: raise "use_d_vector_file is True, so you need pass a external speaker embedding file." elif c.use_speaker_embedding and "speakers_file" in c and c.speakers_file: # new speaker manager with speaker IDs file. - speaker_manager.set_speaker_ids_from_file(c.speakers_file) + speaker_manager.load_ids_from_file(c.speakers_file) if speaker_manager.num_speakers > 0: print( " > Speaker manager is loaded with {} speakers: {}".format( - speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids) + speaker_manager.num_speakers, ", ".join(speaker_manager.ids) ) ) @@ -461,9 +207,9 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None, out_file_path = os.path.join(out_path, "speakers.json") print(f" > Saving `speakers.json` to {out_file_path}.") if c.use_d_vector_file and c.d_vector_file: - speaker_manager.save_d_vectors_to_file(out_file_path) + speaker_manager.save_embeddings_to_file(out_file_path) else: - speaker_manager.save_speaker_ids_to_file(out_file_path) + speaker_manager.save_ids_to_file(out_file_path) return speaker_manager diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index eef4086c..1a49f0b0 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -122,7 +122,7 @@ class Synthesizer(object): self.tts_model.cuda() if self.encoder_checkpoint and hasattr(self.tts_model, "speaker_manager"): - self.tts_model.speaker_manager.init_speaker_encoder(self.encoder_checkpoint, self.encoder_config) + self.tts_model.speaker_manager.init_encoder(self.encoder_checkpoint, self.encoder_config) def _set_speaker_encoder_paths_from_tts_config(self): """Set the encoder paths from the tts model config for models with speaker encoders.""" @@ -212,17 +212,17 @@ class Synthesizer(object): # handle multi-speaker speaker_embedding = None speaker_id = None - if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "speaker_ids"): + if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "ids"): if speaker_name and isinstance(speaker_name, str): if self.tts_config.use_d_vector_file: # get the average speaker embedding from the saved d_vectors. 
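With SpeakerManager reduced to a thin subclass, the ID bookkeeping now lives entirely in BaseIDManager. A minimal sketch of the new entry point, assuming `samples` in the dict form produced by load_tts_samples() (the values are toy stand-ins):

    from TTS.tts.utils.speakers import SpeakerManager

    samples = [
        {"text": "hello", "audio_file": "a.wav", "speaker_name": "spk1"},
        {"text": "world", "audio_file": "b.wav", "speaker_name": "spk2"},
    ]

    manager = SpeakerManager()
    manager.set_ids_from_data(samples, parse_key="speaker_name")
    print(manager.num_speakers)  # 2
    print(manager.ids)           # {'spk1': 0, 'spk2': 1}
    manager.save_ids_to_file("speakers.json")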
diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py
index eef4086c..1a49f0b0 100644
--- a/TTS/utils/synthesizer.py
+++ b/TTS/utils/synthesizer.py
@@ -122,7 +122,7 @@ class Synthesizer(object):
             self.tts_model.cuda()
 
         if self.encoder_checkpoint and hasattr(self.tts_model, "speaker_manager"):
-            self.tts_model.speaker_manager.init_speaker_encoder(self.encoder_checkpoint, self.encoder_config)
+            self.tts_model.speaker_manager.init_encoder(self.encoder_checkpoint, self.encoder_config)
 
     def _set_speaker_encoder_paths_from_tts_config(self):
         """Set the encoder paths from the tts model config for models with speaker encoders."""
@@ -212,17 +212,17 @@ class Synthesizer(object):
         # handle multi-speaker
         speaker_embedding = None
         speaker_id = None
-        if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "speaker_ids"):
+        if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "ids"):
             if speaker_name and isinstance(speaker_name, str):
                 if self.tts_config.use_d_vector_file:
                     # get the average speaker embedding from the saved d_vectors.
-                    speaker_embedding = self.tts_model.speaker_manager.get_mean_d_vector(
+                    speaker_embedding = self.tts_model.speaker_manager.get_mean_embedding(
                         speaker_name, num_samples=None, randomize=False
                     )
                     speaker_embedding = np.array(speaker_embedding)[None, :]  # [1 x embedding_dim]
                 else:
                     # get speaker idx from the speaker name
-                    speaker_id = self.tts_model.speaker_manager.speaker_ids[speaker_name]
+                    speaker_id = self.tts_model.speaker_manager.ids[speaker_name]
 
             elif not speaker_name and not speaker_wav:
                 raise ValueError(
@@ -244,7 +244,7 @@ class Synthesizer(object):
             hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
         ):
             if language_name and isinstance(language_name, str):
-                language_id = self.tts_model.language_manager.language_id_mapping[language_name]
+                language_id = self.tts_model.language_manager.ids[language_name]
 
             elif not language_name:
                 raise ValueError(
@@ -260,7 +260,7 @@ class Synthesizer(object):
 
         # compute a new d_vector from the given clip.
         if speaker_wav is not None:
-            speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(speaker_wav)
+            speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)
 
         use_gl = self.vocoder_model is None
 
@@ -319,7 +319,7 @@ class Synthesizer(object):
             if reference_speaker_name and isinstance(reference_speaker_name, str):
                 if self.tts_config.use_d_vector_file:
                     # get the speaker embedding from the saved d_vectors.
-                    reference_speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(
+                    reference_speaker_embedding = self.tts_model.speaker_manager.get_embeddings_by_name(
                         reference_speaker_name
                     )[0]
                     reference_speaker_embedding = np.array(reference_speaker_embedding)[
@@ -327,9 +327,9 @@ class Synthesizer(object):
                         None, :
                     ]  # [1 x embedding_dim]
                 else:
                     # get speaker idx from the speaker name
-                    reference_speaker_id = self.tts_model.speaker_manager.speaker_ids[reference_speaker_name]
+                    reference_speaker_id = self.tts_model.speaker_manager.ids[reference_speaker_name]
             else:
-                reference_speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(
+                reference_speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(
                     reference_wav
                 )
diff --git a/recipes/multilingual/vits_tts/train_vits_tts.py b/recipes/multilingual/vits_tts/train_vits_tts.py
index 94692f00..0e650ade 100644
--- a/recipes/multilingual/vits_tts/train_vits_tts.py
+++ b/recipes/multilingual/vits_tts/train_vits_tts.py
@@ -119,7 +119,7 @@ train_samples, eval_samples = load_tts_samples(
 # init speaker manager for multi-speaker training
 # it maps speaker-id to speaker-name in the model and data-loader
 speaker_manager = SpeakerManager()
-speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
+speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
 config.model_args.num_speakers = speaker_manager.num_speakers
 
 language_manager = LanguageManager(config=config)
diff --git a/recipes/vctk/fast_pitch/train_fast_pitch.py b/recipes/vctk/fast_pitch/train_fast_pitch.py
index 05cdc72a..c39932da 100644
--- a/recipes/vctk/fast_pitch/train_fast_pitch.py
+++ b/recipes/vctk/fast_pitch/train_fast_pitch.py
@@ -81,7 +81,7 @@ train_samples, eval_samples = load_tts_samples(
 # init speaker manager for multi-speaker training
 # it maps speaker-id to speaker-name in the model and data-loader
 speaker_manager = SpeakerManager()
-speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
+speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
 config.model_args.num_speakers = speaker_manager.num_speakers
 
 # init model
diff --git a/recipes/vctk/fast_speech/train_fast_speech.py b/recipes/vctk/fast_speech/train_fast_speech.py
index a294272a..a3249de1 100644
--- a/recipes/vctk/fast_speech/train_fast_speech.py
+++ b/recipes/vctk/fast_speech/train_fast_speech.py
@@ -79,7 +79,7 @@ train_samples, eval_samples = load_tts_samples(
 # init speaker manager for multi-speaker training
 # it maps speaker-id to speaker-name in the model and data-loader
 speaker_manager = SpeakerManager()
-speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
+speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
 config.model_args.num_speakers = speaker_manager.num_speakers
 
 # init model
diff --git a/recipes/vctk/glow_tts/train_glow_tts.py b/recipes/vctk/glow_tts/train_glow_tts.py
index 0bf686b1..23c02efc 100644
--- a/recipes/vctk/glow_tts/train_glow_tts.py
+++ b/recipes/vctk/glow_tts/train_glow_tts.py
@@ -79,7 +79,7 @@ train_samples, eval_samples = load_tts_samples(
 # init speaker manager for multi-speaker training
 # it maps speaker-id to speaker-name in the model and data-loader
 speaker_manager = SpeakerManager()
-speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
+speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
 config.num_speakers = speaker_manager.num_speakers
 
 # init model
diff --git a/recipes/vctk/speedy_speech/train_speedy_speech.py b/recipes/vctk/speedy_speech/train_speedy_speech.py
index 4208a9b6..bcd0105a 100644
--- a/recipes/vctk/speedy_speech/train_speedy_speech.py
+++ b/recipes/vctk/speedy_speech/train_speedy_speech.py
@@ -79,7 +79,7 @@ train_samples, eval_samples = load_tts_samples(
 # init speaker manager for multi-speaker training
 # it maps speaker-id to speaker-name in the model and data-loader
 speaker_manager = SpeakerManager()
-speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
+speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
 config.model_args.num_speakers = speaker_manager.num_speakers
 
 # init model
diff --git a/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py b/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py
index d67038a4..36e28ed7 100644
--- a/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py
+++ b/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py
@@ -82,7 +82,7 @@ train_samples, eval_samples = load_tts_samples(
 # init speaker manager for multi-speaker training
 # it mainly handles speaker-id to speaker-name for the model and the data-loader
 speaker_manager = SpeakerManager()
-speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
+speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
 
 # init model
 model = Tacotron(config, ap, tokenizer, speaker_manager)
diff --git a/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py b/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py
index b860df85..d04d91c0 100644
--- a/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py
+++ b/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py
@@ -88,7 +88,7 @@ train_samples, eval_samples = load_tts_samples(
 # init speaker manager for multi-speaker training
 # it mainly handles speaker-id to speaker-name for the model and the data-loader
 speaker_manager = SpeakerManager()
-speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
+speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
 
 # init model
 model = Tacotron2(config, ap, tokenizer, speaker_manager)
diff --git a/recipes/vctk/tacotron2/train_tacotron2.py b/recipes/vctk/tacotron2/train_tacotron2.py
index d27dd78c..5a0e157a 100644
--- a/recipes/vctk/tacotron2/train_tacotron2.py
+++ b/recipes/vctk/tacotron2/train_tacotron2.py
@@ -88,7 +88,7 @@ train_samples, eval_samples = load_tts_samples(
 # init speaker manager for multi-speaker training
 # it mainly handles speaker-id to speaker-name for the model and the data-loader
 speaker_manager = SpeakerManager()
-speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
+speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
 
 # init model
 model = Tacotron2(config, ap, tokenizer, speaker_manager)
diff --git a/recipes/vctk/vits/train_vits.py b/recipes/vctk/vits/train_vits.py
index 61d60ca1..88fd7de9 100644
--- a/recipes/vctk/vits/train_vits.py
+++ b/recipes/vctk/vits/train_vits.py
@@ -89,7 +89,7 @@ train_samples, eval_samples = load_tts_samples(
 # init speaker manager for multi-speaker training
 # it maps speaker-id to speaker-name in the model and data-loader
 speaker_manager = SpeakerManager()
-speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
+speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
 config.model_args.num_speakers = speaker_manager.num_speakers
 
 # init model
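The recipes above only touch the speaker side; the language side now goes through the same base class. A sketch, assuming a language_ids.json previously written by save_ids_to_file (the path is hypothetical):

    from TTS.tts.utils.languages import LanguageManager

    # language IDs now live in `ids`, replacing the old `language_id_mapping` attribute
    manager = LanguageManager(language_ids_file_path="language_ids.json")
    print(manager.num_languages)
    print(manager.language_names)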
-        d_vector3 = manager.compute_d_vector_from_clip([sample_wav_path, sample_wav_path2])
+        d_vector3 = manager.compute_embedding_from_clip([sample_wav_path, sample_wav_path2])
         d_vector3 = torch.FloatTensor(d_vector3)
         assert d_vector3.shape[0] == 256
         assert (d_vector - d_vector3).sum() != 0.0
@@ -62,14 +62,14 @@ class SpeakerManagerTest(unittest.TestCase):
     def test_speakers_file_processing():
         manager = SpeakerManager(d_vectors_file_path=d_vectors_file_path)
         print(manager.num_speakers)
-        print(manager.d_vector_dim)
+        print(manager.embedding_dim)
         print(manager.clip_ids)
-        d_vector = manager.get_d_vector_by_clip(manager.clip_ids[0])
+        d_vector = manager.get_embedding_by_clip(manager.clip_ids[0])
         assert len(d_vector) == 256
-        d_vectors = manager.get_d_vectors_by_speaker(manager.speaker_names[0])
+        d_vectors = manager.get_embeddings_by_name(manager.speaker_names[0])
         assert len(d_vectors[0]) == 256
-        d_vector1 = manager.get_mean_d_vector(manager.speaker_names[0], num_samples=2, randomize=True)
+        d_vector1 = manager.get_mean_embedding(manager.speaker_names[0], num_samples=2, randomize=True)
         assert len(d_vector1) == 256
-        d_vector2 = manager.get_mean_d_vector(manager.speaker_names[0], num_samples=2, randomize=False)
+        d_vector2 = manager.get_mean_embedding(manager.speaker_names[0], num_samples=2, randomize=False)
         assert len(d_vector2) == 256
         assert np.sum(np.array(d_vector1) - np.array(d_vector2)) != 0
diff --git a/tests/tts_tests/test_glow_tts.py b/tests/tts_tests/test_glow_tts.py
index 2783e4bd..2a723f10 100644
--- a/tests/tts_tests/test_glow_tts.py
+++ b/tests/tts_tests/test_glow_tts.py
@@ -86,7 +86,7 @@ class TestGlowTTS(unittest.TestCase):
         model = GlowTTS(config)
         model.speaker_manager = speaker_manager
         model.init_multispeaker(config)
-        self.assertEqual(model.c_in_channels, speaker_manager.d_vector_dim)
+        self.assertEqual(model.c_in_channels, speaker_manager.embedding_dim)
         self.assertEqual(model.num_speakers, speaker_manager.num_speakers)
 
     def test_unlock_act_norm_layers(self):
diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py
index 05adb9ed..de683c81 100644
--- a/tests/tts_tests/test_vits.py
+++ b/tests/tts_tests/test_vits.py
@@ -7,7 +7,7 @@ from trainer.logging.tensorboard_logger import TensorboardLogger
 
 from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_tests_input_path, get_tests_output_path
 from TTS.config import load_config
-from TTS.encoder.utils.generic_utils import setup_speaker_encoder_model
+from TTS.encoder.utils.generic_utils import setup_encoder_model
 from TTS.tts.configs.vits_config import VitsConfig
 from TTS.tts.models.vits import Vits, VitsArgs, amp_to_db, db_to_amp, load_audio, spec_to_mel, wav_to_mel, wav_to_spec
 from TTS.tts.utils.speakers import SpeakerManager
@@ -242,9 +242,9 @@ class TestVits(unittest.TestCase):
 
         speaker_encoder_config = load_config(SPEAKER_ENCODER_CONFIG)
         speaker_encoder_config.model_params["use_torch_spec"] = True
-        speaker_encoder = setup_speaker_encoder_model(speaker_encoder_config).to(device)
+        speaker_encoder = setup_encoder_model(speaker_encoder_config).to(device)
         speaker_manager = SpeakerManager()
-        speaker_manager.speaker_encoder = speaker_encoder
+        speaker_manager.encoder = speaker_encoder
 
         args = VitsArgs(
             language_ids_file=LANG_FILE,
diff --git a/tests/zoo_tests/test_models.py b/tests/zoo_tests/test_models.py
index 63d9e7ca..e614ce74 100644
--- a/tests/zoo_tests/test_models.py
+++ b/tests/zoo_tests/test_models.py
@@ -38,7 +38,7 @@ def test_run_all_models():
                 language_manager = LanguageManager(language_ids_file_path=language_files[0])
                 language_id = language_manager.language_names[0]
 
-            speaker_id = list(speaker_manager.speaker_ids.keys())[0]
+            speaker_id = list(speaker_manager.ids.keys())[0]
             run_cli(
                 f"tts --model_name {model_name} "
                 f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" '
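Finally, the speaker-encoder path after the rename, as a sketch; the checkpoint, config, and wav paths are placeholders for a trained speaker encoder:

    from TTS.tts.utils.speakers import SpeakerManager

    manager = SpeakerManager(
        encoder_model_path="speaker_encoder_model.pth.tar",
        encoder_config_path="speaker_encoder_config.json",
        use_cuda=False,
    )
    # one d-vector from a single clip ...
    d_vector = manager.compute_embedding_from_clip("sample.wav")
    # ... or the mean d-vector over several clips of the same speaker
    d_vector_mean = manager.compute_embedding_from_clip(["sample_1.wav", "sample_2.wav"])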