From f996afedb0b0be325359bbe1ab758cdc05ef44b5 Mon Sep 17 00:00:00 2001 From: Edresson Date: Fri, 13 Aug 2021 19:58:56 -0300 Subject: [PATCH] Implement multilingual dataloader support --- TTS/config/shared_configs.py | 3 +++ TTS/trainer.py | 14 ++++++++++++++ TTS/tts/datasets/__init__.py | 6 ++++++ TTS/tts/datasets/dataset.py | 23 +++++++++++++++++++---- TTS/tts/models/base_tts.py | 31 ++++++++++++++++++++++++++----- TTS/tts/utils/text/cleaners.py | 8 ++++++++ 6 files changed, 76 insertions(+), 9 deletions(-) diff --git a/TTS/config/shared_configs.py b/TTS/config/shared_configs.py index d91bf2b6..f1ea2e0f 100644 --- a/TTS/config/shared_configs.py +++ b/TTS/config/shared_configs.py @@ -199,6 +199,7 @@ class BaseDatasetConfig(Coqpit): path: str = "" meta_file_train: str = "" ununsed_speakers: List[str] = None + language: str = "" meta_file_val: str = "" meta_file_attn_mask: str = "" @@ -335,6 +336,8 @@ class BaseTrainingConfig(Coqpit): num_loader_workers: int = 0 num_eval_loader_workers: int = 0 use_noise_augment: bool = False + use_language_weighted_sampler: bool = False + # paths output_path: str = None # distributed diff --git a/TTS/trainer.py b/TTS/trainer.py index 2a2cfc46..2175875c 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -260,6 +260,20 @@ class Trainer: else: self.run_get_model(self.config, get_model) + if hasattr(self.model, "init_multilingual"): + self.model.init_multilingual(self.config, self.data_train + self.data_eval) + config = self.config.model_args if hasattr(self.config, "model_args") else self.config + # save speakers json + if config.use_language_embedding and self.model.language_manager.num_languages > 1: + self.model.language_manager.save_language_ids_to_file(os.path.join(self.output_path, "language_ids.json")) + if hasattr(self.config, "model_args"): + self.config.model_args["num_languages"] = self.model.language_manager.num_languages + else: + self.config.num_languages = self.model.language_manager.num_languages + + # update config 
file + copy_model_files(self.config, self.output_path, None) + # setup criterion self.criterion = self.get_criterion(self.model) diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index 741f92fd..3673e188 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -68,16 +68,22 @@ def load_tts_samples( meta_file_train = dataset["meta_file_train"] meta_file_val = dataset["meta_file_val"] ununsed_speakers = dataset["ununsed_speakers"] + language = dataset["language"] + # setup the right data processor if formatter is None: formatter = _get_formatter_by_name(name) # load train set meta_data_train = formatter(root_path, meta_file_train, ununsed_speakers=ununsed_speakers) + # TODO: remove the loops and pass language as a parameter to preprocessor for faster load + meta_data_train = [[*item, language] for item in meta_data_train] + print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}") # load evaluation split if set if eval_split: if meta_file_val: meta_data_eval = formatter(root_path, meta_file_val, ununsed_speakers=ununsed_speakers) + meta_data_eval = [[*item, language] for item in meta_data_eval] else: meta_data_eval, meta_data_train = split_dataset(meta_data_train) meta_data_eval_all += meta_data_eval diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 04314bab..7ba97eba 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -37,6 +37,7 @@ class TTSDataset(Dataset): enable_eos_bos: bool = False, speaker_id_mapping: Dict = None, d_vector_mapping: Dict = None, + language_id_mapping: Dict = None, use_noise_augment: bool = False, verbose: bool = False, ): @@ -122,6 +123,7 @@ class TTSDataset(Dataset): self.enable_eos_bos = enable_eos_bos self.speaker_id_mapping = speaker_id_mapping self.d_vector_mapping = d_vector_mapping + self.language_id_mapping = language_id_mapping self.use_noise_augment = use_noise_augment self.verbose = verbose 
self.input_seq_computed = False @@ -197,10 +199,10 @@ class TTSDataset(Dataset): def load_data(self, idx): item = self.items[idx] - if len(item) == 4: - text, wav_file, speaker_name, attn_file = item + if len(item) == 5: + text, wav_file, speaker_name, language_name, attn_file = item else: - text, wav_file, speaker_name = item + text, wav_file, speaker_name, language_name = item attn = None raw_text = text @@ -218,7 +220,7 @@ class TTSDataset(Dataset): self.phoneme_cache_path, self.enable_eos_bos, self.cleaners, - self.phoneme_language, + language_name if language_name else self.phoneme_language, self.custom_symbols, self.characters, self.add_blank, @@ -260,6 +262,7 @@ class TTSDataset(Dataset): "attn": attn, "item_idx": self.items[idx][1], "speaker_name": speaker_name, + "language_name": language_name, "wav_file_name": os.path.basename(wav_file), } return sample @@ -413,6 +416,14 @@ class TTSDataset(Dataset): # convert list of dicts to dict of lists batch = {k: [dic[k] for dic in batch] for k in batch[0]} + speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing] + + # get language ids from language names + if self.language_id_mapping is not None: + language_names = [batch[idx]["language_name"] for idx in ids_sorted_decreasing] + language_ids = [self.language_id_mapping[ln] for ln in language_names] + else: + language_ids = None # get pre-computed d-vectors if self.d_vector_mapping is not None: wav_files_names = [batch["wav_file_name"][idx] for idx in ids_sorted_decreasing] @@ -466,6 +477,9 @@ class TTSDataset(Dataset): if speaker_ids is not None: speaker_ids = torch.LongTensor(speaker_ids) + if language_ids is not None: + language_ids = torch.LongTensor(language_ids) + # compute linear spectrogram if self.compute_linear_spec: linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]] @@ -528,6 +542,7 @@ class TTSDataset(Dataset): "waveform": wav_padded, "raw_text": batch["raw_text"], "pitch": pitch, + "language_ids": 
language_ids } raise TypeError( diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 854526de..c55936a8 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -13,6 +13,7 @@ from TTS.model import BaseModel from TTS.tts.configs.shared_configs import CharactersConfig from TTS.tts.datasets.dataset import TTSDataset from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager +from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text import make_symbols from TTS.tts.utils.visual import plot_alignment, plot_spectrogram @@ -73,9 +74,18 @@ class BaseTTS(BaseModel): def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager: return get_speaker_manager(config, restore_path, data, out_path) - def init_multispeaker(self, config: Coqpit): - """Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding - vector dimension in the network. If model uses d-vectors, then it only sets the expected dimension. + def init_multispeaker(self, config: Coqpit, data: List = None): + """Initialize a speaker embedding layer if needed and define expected embedding channel size for defining + `in_channels` size of the connected layers. + + This implementation yields 3 possible outcomes: + + 1. If `config.use_speaker_embedding` and `config.use_d_vector_file` are False, do nothing. + 2. If `config.use_d_vector_file` is True, set expected embedding channel size to `config.d_vector_dim` or 512. + 3. If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of + `config.d_vector_dim` or 512. + + You can override this function for new models. Args: config (Coqpit): Model configuration.
@@ -122,6 +132,7 @@ class BaseTTS(BaseModel): attn_mask = batch["attns"] waveform = batch["waveform"] pitch = batch["pitch"] + language_ids = batch["language_ids"] max_text_length = torch.max(text_lengths.float()) max_spec_length = torch.max(mel_lengths.float()) @@ -169,6 +180,7 @@ class BaseTTS(BaseModel): "item_idx": item_idx, "waveform": waveform, "pitch": pitch, + "language_ids": language_ids, } def get_data_loader( @@ -199,7 +211,12 @@ class BaseTTS(BaseModel): if hasattr(self, "make_symbols"): custom_symbols = self.make_symbols(self.config) - # init dataset + if hasattr(self, "language_manager"): + language_id_mapping = self.language_manager.language_id_mapping if self.args.use_language_embedding else None + else: + language_id_mapping = None + + # init dataloader dataset = TTSDataset( outputs_per_step=config.r if "r" in config else 1, text_cleaner=config.text_cleaner, @@ -223,6 +240,7 @@ class BaseTTS(BaseModel): verbose=verbose, speaker_id_mapping=speaker_id_mapping, d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, + language_id_mapping=language_id_mapping, ) # pre-compute phonemes @@ -267,8 +285,11 @@ class BaseTTS(BaseModel): # sampler for DDP sampler = DistributedSampler(dataset) if num_gpus > 1 else None + if sampler is None: + if getattr(config, "use_language_weighted_sampler", False): + sampler = get_language_weighted_sampler(dataset.items) + print(" > Using Language weighted sampler") - # init dataloader loader = DataLoader( dataset, batch_size=config.eval_batch_size if is_eval else config.batch_size, diff --git a/TTS/tts/utils/text/cleaners.py b/TTS/tts/utils/text/cleaners.py index 4b041ed8..71155ebc 100644 --- a/TTS/tts/utils/text/cleaners.py +++ b/TTS/tts/utils/text/cleaners.py @@ -135,3 +135,11 @@ def phoneme_cleaners(text): text = remove_aux_symbols(text) text = collapse_whitespace(text) return text + +def multilingual_cleaners(text): + '''Pipeline for multilingual text''' + text = lowercase(text) + text = 
replace_symbols(text, lang=None) + text = remove_aux_symbols(text) + text = collapse_whitespace(text) + return text \ No newline at end of file