From e4049aa31a0a27e49267613e76806ff4df4f23c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 16 Nov 2021 13:33:21 +0100 Subject: [PATCH] Refactor TTSDataset to use TTSTokenizer --- TTS/tts/datasets/dataset.py | 113 +++++++++--------------------------- 1 file changed, 28 insertions(+), 85 deletions(-) diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 546f012d..8c21d7d0 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -10,7 +10,7 @@ import tqdm from torch.utils.data import Dataset from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor -from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence +from TTS.tts.utils.text import TTSTokenizer from TTS.utils.audio import AudioProcessor @@ -18,23 +18,17 @@ class TTSDataset(Dataset): def __init__( self, outputs_per_step: int, - text_cleaner: list, compute_linear_spec: bool, ap: AudioProcessor, meta_data: List[Dict], + tokenizer: "TTSTokenizer" = None, compute_f0: bool = False, f0_cache_path: str = None, - characters: Dict = None, - custom_symbols: List = None, - add_blank: bool = False, return_wav: bool = False, batch_group_size: int = 0, min_seq_len: int = 0, max_seq_len: int = float("inf"), - use_phonemes: bool = False, phoneme_cache_path: str = None, - phoneme_language: str = "en-us", - enable_eos_bos: bool = False, speaker_id_mapping: Dict = None, d_vector_mapping: Dict = None, language_id_mapping: Dict = None, @@ -48,26 +42,19 @@ class TTSDataset(Dataset): Args: outputs_per_step (int): Number of time frames predicted per step. - text_cleaner (list): List of text cleaners to clean the input text before converting to sequence IDs. - compute_linear_spec (bool): compute linear spectrogram if True. ap (TTS.tts.utils.AudioProcessor): Audio processor object. meta_data (list): List of dataset samples. + tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. 
If None, it is initialized internally; otherwise + the given tokenizer is used. Defaults to None. + + compute_f0 (bool): compute f0 if True. Defaults to False. f0_cache_path (str): Path to store f0 cache. Defaults to None. - characters (dict): `dict` of custom text characters used for converting texts to sequences. - - custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own - set of symbols need to pass it here. Defaults to `None`. - - add_blank (bool): Add a special `blank` character after every other character. It helps some - models achieve better results. Defaults to false. - return_wav (bool): Return the waveform of the sample. Defaults to False. batch_group_size (int): Range of batch randomization after sorting @@ -82,16 +69,9 @@ class TTSDataset(Dataset): It helps for controlling the VRAM usage against long input sequences. Especially models with RNN layers are sensitive to input length. Defaults to `Inf`. - use_phonemes (bool): If true, input text converted to phonemes. Defaults to false. - phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a separate file. Defaults to None. - phoneme_language (str): One the languages from supported by the phonemizer interface. Defaults to `en-us`. - - enable_eos_bos (bool): Enable the `end of sentence` and the `beginning of sentences characters`. Defaults - to False. - speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the embedding layer. Defaults to None. 
@@ -106,7 +86,6 @@ class TTSDataset(Dataset): self.items = meta_data self.outputs_per_step = outputs_per_step self.sample_rate = ap.sample_rate - self.cleaners = text_cleaner self.compute_linear_spec = compute_linear_spec self.return_wav = return_wav self.compute_f0 = compute_f0 @@ -114,13 +93,7 @@ class TTSDataset(Dataset): self.min_seq_len = min_seq_len self.max_seq_len = max_seq_len self.ap = ap - self.characters = characters - self.custom_symbols = custom_symbols - self.add_blank = add_blank - self.use_phonemes = use_phonemes self.phoneme_cache_path = phoneme_cache_path - self.phoneme_language = phoneme_language - self.enable_eos_bos = enable_eos_bos self.speaker_id_mapping = speaker_id_mapping self.d_vector_mapping = d_vector_mapping self.language_id_mapping = language_id_mapping @@ -130,17 +103,23 @@ class TTSDataset(Dataset): self.input_seq_computed = False self.rescue_item_idx = 1 self.pitch_computed = False + self.tokenizer = tokenizer - if use_phonemes and not os.path.isdir(phoneme_cache_path): + if self.tokenizer.use_phonemes and not os.path.isdir(phoneme_cache_path): os.makedirs(phoneme_cache_path, exist_ok=True) if compute_f0: self.pitch_extractor = PitchExtractor(self.items, verbose=verbose) + if self.verbose: - print("\n > DataLoader initialization") - print(" | > Use phonemes: {}".format(self.use_phonemes)) - if use_phonemes: - print(" | > phoneme language: {}".format(phoneme_language)) - print(" | > Number of instances : {}".format(len(self.items))) + self.print_logs() + + def print_logs(self, level: int = 0) -> None: + indent = "\t" * level + print("\n") + print(f"{indent}> DataLoader initialization") + print(f"{indent}| > Tokenizer:") + self.tokenizer.print_logs(level + 1) + print(f"{indent}| > Number of instances : {len(self.items)}") def load_wav(self, filename): audio = self.ap.load_wav(filename) @@ -152,48 +131,30 @@ class TTSDataset(Dataset): return data @staticmethod - def _generate_and_cache_phoneme_sequence( - text, cache_path, cleaners, 
language, custom_symbols, characters, add_blank - ): + def _generate_and_cache_phoneme_sequence(text, tokenizer, cache_path): """generate a phoneme sequence from text. since the usage is for subsequent caching, we never add bos and eos chars here. Instead we add those dynamically later; based on the config option.""" - phonemes = phoneme_to_sequence( - text, - [cleaners], - language=language, - enable_eos_bos=False, - custom_symbols=custom_symbols, - tp=characters, - add_blank=add_blank, - ) + phonemes = tokenizer.text_to_ids(text) phonemes = np.asarray(phonemes, dtype=np.int32) np.save(cache_path, phonemes) return phonemes @staticmethod - def _load_or_generate_phoneme_sequence( - wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, custom_symbols, characters, add_blank - ): + def _load_or_generate_phoneme_sequence(wav_file, text, language, tokenizer, phoneme_cache_path): file_name = os.path.splitext(os.path.basename(wav_file))[0] # different names for normal phonemes and with blank chars. - file_name_ext = "_blanked_phoneme.npy" if add_blank else "_phoneme.npy" + file_name_ext = "_phoneme.npy" cache_path = os.path.join(phoneme_cache_path, file_name + file_name_ext) try: phonemes = np.load(cache_path) except FileNotFoundError: - phonemes = TTSDataset._generate_and_cache_phoneme_sequence( - text, cache_path, cleaners, language, custom_symbols, characters, add_blank - ) + phonemes = TTSDataset._generate_and_cache_phoneme_sequence(text, tokenizer, cache_path) except (ValueError, IOError): print(" [!] failed loading phonemes for {}. 
" "Recomputing.".format(wav_file)) - phonemes = TTSDataset._generate_and_cache_phoneme_sequence( - text, cache_path, cleaners, language, custom_symbols, characters, add_blank - ) - if enable_eos_bos: - phonemes = pad_with_eos_bos(phonemes, tp=characters) + phonemes = TTSDataset._generate_and_cache_phoneme_sequence(text, tokenizer, cache_path) phonemes = np.asarray(phonemes, dtype=np.int32) return phonemes @@ -208,27 +169,17 @@ class TTSDataset(Dataset): wav = wav + (1.0 / 32768.0) * np.random.rand(*wav.shape) if not self.input_seq_computed: - if self.use_phonemes: + if self.tokenizer.use_phonemes: text = self._load_or_generate_phoneme_sequence( item["audio_file"], item["text"], - self.phoneme_cache_path, - self.enable_eos_bos, - self.cleaners, item["language"] if item["language"] else self.phoneme_language, - self.custom_symbols, - self.characters, - self.add_blank, + self.tokenizer, + self.phoneme_cache_path, ) else: text = np.asarray( - text_to_sequence( - item["text"], - [self.cleaners], - custom_symbols=self.custom_symbols, - tp=self.characters, - add_blank=self.add_blank, - ), + self.tokenizer.text_to_ids(item["text"], item["language"]), dtype=np.int32, ) @@ -281,24 +232,16 @@ class TTSDataset(Dataset): print(" | > Computing input sequences ...") for idx, item in enumerate(tqdm.tqdm(self.items)): sequence = np.asarray( - text_to_sequence( - item["text"], - [self.cleaners], - custom_symbols=self.custom_symbols, - tp=self.characters, - add_blank=self.add_blank, - ), + self.tokenizer.text_to_ids(item["text"]), dtype=np.int32, ) self.items[idx][0] = sequence - else: func_args = [ self.phoneme_cache_path, self.enable_eos_bos, self.cleaners, self.phoneme_language, - self.custom_symbols, self.characters, self.add_blank, ]