From 003e5579e82ff9e7f5b23bb70ee0cd6ff1e579de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sat, 7 Aug 2021 21:46:10 +0000 Subject: [PATCH] Enable `custom_symbols` in text processing Models can define their own custom symbols lists with custom `make_symbols()` --- TTS/tts/datasets/TTSDataset.py | 72 +++++++++++++++++++++---- TTS/tts/utils/synthesis.py | 28 +++++----- TTS/tts/utils/text/__init__.py | 56 ++++++++++++++----- tests/data_tests/test_loader.py | 13 ++++- tests/tts_tests/test_align_tts_train.py | 1 + 5 files changed, 134 insertions(+), 36 deletions(-) diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py index 0fc23231..aaa0ba50 100644 --- a/TTS/tts/datasets/TTSDataset.py +++ b/TTS/tts/datasets/TTSDataset.py @@ -23,7 +23,9 @@ class TTSDataset(Dataset): ap: AudioProcessor, meta_data: List[List], characters: Dict = None, + custom_symbols: List = None, add_blank: bool = False, + return_wav: bool = False, batch_group_size: int = 0, min_seq_len: int = 0, max_seq_len: int = float("inf"), @@ -54,9 +56,14 @@ class TTSDataset(Dataset): characters (dict): `dict` of custom text characters used for converting texts to sequences. + custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own + set of symbols need to pass it here. Defaults to `None`. + add_blank (bool): Add a special `blank` character after every other character. It helps some models achieve better results. Defaults to false. + return_wav (bool): Return the waveform of the sample. Defaults to False. + batch_group_size (int): Range of batch randomization after sorting sequences by length. It shuffles each batch with bucketing to gather similar lenght sequences in a batch. Set 0 to disable. Defaults to 0. @@ -95,10 +102,12 @@ class TTSDataset(Dataset): self.sample_rate = ap.sample_rate self.cleaners = text_cleaner self.compute_linear_spec = compute_linear_spec + self.return_wav = return_wav self.min_seq_len = min_seq_len self.max_seq_len = max_seq_len self.ap = ap self.characters = characters + self.custom_symbols = custom_symbols self.add_blank = add_blank self.use_phonemes = use_phonemes self.phoneme_cache_path = phoneme_cache_path @@ -109,6 +118,7 @@ class TTSDataset(Dataset): self.use_noise_augment = use_noise_augment self.verbose = verbose self.input_seq_computed = False + self.rescue_item_idx = 1 if use_phonemes and not os.path.isdir(phoneme_cache_path): os.makedirs(phoneme_cache_path, exist_ok=True) if self.verbose: @@ -128,13 +138,21 @@ class TTSDataset(Dataset): return data @staticmethod - def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, characters, add_blank): + def _generate_and_cache_phoneme_sequence( + text, cache_path, cleaners, language, custom_symbols, characters, add_blank + ): """generate a phoneme sequence from text. since the usage is for subsequent caching, we never add bos and eos chars here. Instead we add those dynamically later; based on the config option.""" phonemes = phoneme_to_sequence( - text, [cleaners], language=language, enable_eos_bos=False, tp=characters, add_blank=add_blank + text, + [cleaners], + language=language, + enable_eos_bos=False, + custom_symbols=custom_symbols, + tp=characters, + add_blank=add_blank, ) phonemes = np.asarray(phonemes, dtype=np.int32) np.save(cache_path, phonemes) @@ -142,7 +160,7 @@ class TTSDataset(Dataset): @staticmethod def _load_or_generate_phoneme_sequence( - wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, characters, add_blank + wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, custom_symbols, characters, add_blank ): file_name = os.path.splitext(os.path.basename(wav_file))[0] @@ -153,12 +171,12 @@ class TTSDataset(Dataset): phonemes = np.load(cache_path) except FileNotFoundError: phonemes = TTSDataset._generate_and_cache_phoneme_sequence( - text, cache_path, cleaners, language, characters, add_blank + text, cache_path, cleaners, language, custom_symbols, characters, add_blank ) except (ValueError, IOError): print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file)) phonemes = TTSDataset._generate_and_cache_phoneme_sequence( - text, cache_path, cleaners, language, characters, add_blank + text, cache_path, cleaners, language, custom_symbols, characters, add_blank ) if enable_eos_bos: phonemes = pad_with_eos_bos(phonemes, tp=characters) @@ -189,13 +207,19 @@ class TTSDataset(Dataset): self.enable_eos_bos, self.cleaners, self.phoneme_language, + self.custom_symbols, self.characters, self.add_blank, ) - else: text = np.asarray( - text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank), + text_to_sequence( + text, + [self.cleaners], + custom_symbols=self.custom_symbols, + tp=self.characters, + add_blank=self.add_blank, + ), dtype=np.int32, ) @@ -209,7 +233,7 @@ class TTSDataset(Dataset): # return a different sample if the phonemized # text is longer than the threshold # TODO: find a better fix - return self.load_data(100) + return self.load_data(self.rescue_item_idx) sample = { "text": text, @@ -238,7 +262,13 @@ class TTSDataset(Dataset): for idx, item in enumerate(tqdm.tqdm(self.items)): text, *_ = item sequence = np.asarray( - text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank), + text_to_sequence( + text, + [self.cleaners], + custom_symbols=self.custom_symbols, + tp=self.characters, + add_blank=self.add_blank, + ), dtype=np.int32, ) self.items[idx][0] = sequence @@ -249,6 +279,7 @@ class TTSDataset(Dataset): self.enable_eos_bos, self.cleaners, self.phoneme_language, + self.custom_symbols, self.characters, self.add_blank, ] @@ -347,6 +378,14 @@ class TTSDataset(Dataset): mel_lengths = [m.shape[1] for m in mel] + # lengths adjusted by the reduction factor + mel_lengths_adjusted = [ + m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step)) + if m.shape[1] % self.outputs_per_step + else m.shape[1] + for m in mel + ] + # compute 'stop token' targets stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths] @@ -385,6 +424,20 @@ class TTSDataset(Dataset): else: linear = None + # format waveforms + wav_padded = None + if self.return_wav: + wav_lengths = [w.shape[0] for w in wav] + max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length + wav_lengths = torch.LongTensor(wav_lengths) + wav_padded = torch.zeros(len(batch), 1, max_wav_len) + for i, w in enumerate(wav): + mel_length = mel_lengths_adjusted[i] + w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge") + w = w[: mel_length * self.ap.hop_length] + wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w) + wav_padded.transpose_(1, 2) + # collate attention alignments if batch[0]["attn"] is not None: attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing] @@ -409,6 +462,7 @@ class TTSDataset(Dataset): d_vectors, speaker_ids, attns, + wav_padded, ) raise TypeError( diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 39474cab..ca15f4cc 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -15,7 +15,7 @@ if "tensorflow" in installed or "tensorflow-gpu" in installed: import tensorflow as tf -def text_to_seq(text, CONFIG): +def text_to_seq(text, CONFIG, custom_symbols=None): text_cleaner = [CONFIG.text_cleaner] # text ot phonemes to sequence vector if CONFIG.use_phonemes: @@ -28,16 +28,14 @@ def text_to_seq(text, CONFIG): tp=CONFIG.characters, add_blank=CONFIG.add_blank, use_espeak_phonemes=CONFIG.use_espeak_phonemes, + custom_symbols=custom_symbols, ), dtype=np.int32, ) else: seq = np.asarray( text_to_sequence( - text, - text_cleaner, - tp=CONFIG.characters, - add_blank=CONFIG.add_blank, + text, text_cleaner, tp=CONFIG.characters, add_blank=CONFIG.add_blank, custom_symbols=custom_symbols ), dtype=np.int32, ) @@ -229,13 +227,16 @@ def synthesis( """ # GST processing style_mel = None + custom_symbols = None if CONFIG.has("gst") and CONFIG.gst and style_wav is not None: if isinstance(style_wav, dict): style_mel = style_wav else: style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda) + if hasattr(model, "make_symbols"): + custom_symbols = model.make_symbols(CONFIG) # preprocess the given text - text_inputs = text_to_seq(text, CONFIG) + text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols) # pass tensors to backend if backend == "torch": if speaker_id is not None: @@ -274,15 +275,18 @@ def synthesis( # convert outputs to numpy # plot results wav = None - if use_griffin_lim: - wav = inv_spectrogram(model_outputs, ap, CONFIG) - # trim silence - if do_trim_silence: - wav = trim_silence(wav, ap) + if hasattr(model, "END2END") and model.END2END: + wav = model_outputs.squeeze(0) + else: + if use_griffin_lim: + wav = inv_spectrogram(model_outputs, ap, CONFIG) + # trim silence + if do_trim_silence: + wav = trim_silence(wav, ap) return_dict = { "wav": wav, "alignments": alignments, - "model_outputs": model_outputs, "text_inputs": text_inputs, + "outputs": outputs, } return return_dict diff --git a/TTS/tts/utils/text/__init__.py b/TTS/tts/utils/text/__init__.py index fdccf7f1..48f69374 100644 --- a/TTS/tts/utils/text/__init__.py +++ b/TTS/tts/utils/text/__init__.py @@ -2,10 +2,9 @@ # adapted from https://github.com/keithito/tacotron import re -import unicodedata +from typing import Dict, List import gruut -from packaging import version from TTS.tts.utils.text import cleaners from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes @@ -22,6 +21,7 @@ _id_to_phonemes = {i: s for i, s in enumerate(phonemes)} _symbols = symbols _phonemes = phonemes + # Regular expression matching text enclosed in curly braces: _CURLY_RE = re.compile(r"(.*?)\{(.+?)\}(.*)") @@ -81,7 +81,7 @@ def text2phone(text, language, use_espeak_phonemes=False): # Fix a few phonemes ph = ph.translate(GRUUT_TRANS_TABLE) - print(" > Phonemes: {}".format(ph)) + # print(" > Phonemes: {}".format(ph)) return ph raise ValueError(f" [!] Language {language} is not supported for phonemization.") @@ -106,13 +106,38 @@ def pad_with_eos_bos(phoneme_sequence, tp=None): def phoneme_to_sequence( - text, cleaner_names, language, enable_eos_bos=False, tp=None, add_blank=False, use_espeak_phonemes=False -): + text: str, + cleaner_names: List[str], + language: str, + enable_eos_bos: bool = False, + custom_symbols: List[str] = None, + tp: Dict = None, + add_blank: bool = False, + use_espeak_phonemes: bool = False, +) -> List[int]: + """Converts a string of phonemes to a sequence of IDs. + + Args: + text (str): string to convert to a sequence + cleaner_names (List[str]): names of the cleaner functions to run the text through + language (str): text language key for phonemization. + enable_eos_bos (bool): whether to append the end-of-sentence and beginning-of-sentence tokens. + tp (Dict): dictionary of character parameters to use a custom character set. + add_blank (bool): option to add a blank token between each token. + use_espeak_phonemes (bool): use espeak based lexicons to convert phonemes to sequenc + + Returns: + List[int]: List of integers corresponding to the symbols in the text + """ # pylint: disable=global-statement global _phonemes_to_id, _phonemes + if tp: _, _phonemes = make_symbols(**tp) _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)} + elif custom_symbols is not None: + _phonemes = custom_symbols + _phonemes_to_id = {s: i for i, s in enumerate(custom_symbols)} sequence = [] clean_text = _clean_text(text, cleaner_names) @@ -127,7 +152,6 @@ def phoneme_to_sequence( sequence = pad_with_eos_bos(sequence, tp=tp) if add_blank: sequence = intersperse(sequence, len(_phonemes)) # add a blank token (new), whose id number is len(_phonemes) - return sequence @@ -149,27 +173,31 @@ def sequence_to_phoneme(sequence, tp=None, add_blank=False): return result.replace("}{", " ") -def text_to_sequence(text, cleaner_names, tp=None, add_blank=False): +def text_to_sequence( + text: str, cleaner_names: List[str], custom_symbols: List[str] = None, tp: Dict = None, add_blank: bool = False +) -> List[int]: """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. - The text can optionally have ARPAbet sequences enclosed in curly braces embedded - in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." - Args: - text: string to convert to a sequence - cleaner_names: names of the cleaner functions to run the text through - tp: dictionary of character parameters to use a custom character set. + text (str): string to convert to a sequence + cleaner_names (List[str]): names of the cleaner functions to run the text through + tp (Dict): dictionary of character parameters to use a custom character set. + add_blank (bool): option to add a blank token between each token. Returns: - List of integers corresponding to the symbols in the text + List[int]: List of integers corresponding to the symbols in the text """ # pylint: disable=global-statement global _symbol_to_id, _symbols if tp: _symbols, _ = make_symbols(**tp) _symbol_to_id = {s: i for i, s in enumerate(_symbols)} + elif custom_symbols is not None: + _symbols = custom_symbols + _symbol_to_id = {s: i for i, s in enumerate(custom_symbols)} sequence = [] + # Check for curly braces and treat their contents as ARPAbet: while text: m = _CURLY_RE.match(text) diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py index 3fd3eaef..10067094 100644 --- a/tests/data_tests/test_loader.py +++ b/tests/data_tests/test_loader.py @@ -42,6 +42,7 @@ class TestTTSDataset(unittest.TestCase): r, c.text_cleaner, compute_linear_spec=True, + return_wav=True, ap=self.ap, meta_data=items, characters=c.characters, @@ -75,16 +76,26 @@ class TestTTSDataset(unittest.TestCase): mel_lengths = data[5] stop_target = data[6] item_idx = data[7] + wavs = data[11] neg_values = text_input[text_input < 0] check_count = len(neg_values) assert check_count == 0, " !! Negative values in text_input: {}".format(check_count) - # TODO: more assertion here assert isinstance(speaker_name[0], str) assert linear_input.shape[0] == c.batch_size assert linear_input.shape[2] == self.ap.fft_size // 2 + 1 assert mel_input.shape[0] == c.batch_size assert mel_input.shape[2] == c.audio["num_mels"] + assert ( + wavs.shape[1] == mel_input.shape[1] * c.audio.hop_length + ), f"wavs.shape: {wavs.shape[1]}, mel_input.shape: {mel_input.shape[1] * c.audio.hop_length}" + + # make sure that the computed mels and the waveform match and correctly computed + mel_new = self.ap.melspectrogram(wavs[0].squeeze().numpy()) + ignore_seg = -(1 + c.audio.win_length // c.audio.hop_length) + mel_diff = (mel_new[:, : mel_input.shape[1]] - mel_input[0].T.numpy())[:, 0:ignore_seg] + assert abs(mel_diff.sum()) < 1e-5 + # check normalization ranges if self.ap.symmetric_norm: assert mel_input.max() <= self.ap.max_norm diff --git a/tests/tts_tests/test_align_tts_train.py b/tests/tts_tests/test_align_tts_train.py index 3700b1d3..f04a2358 100644 --- a/tests/tts_tests/test_align_tts_train.py +++ b/tests/tts_tests/test_align_tts_train.py @@ -27,6 +27,7 @@ config = AlignTTSConfig( "Be a voice, not an echo.", ], ) + config.audio.do_trim_silence = True config.audio.trim_db = 60 config.save_json(config_path)