Enable `custom_symbols` in text processing

Models can define their own symbol list by implementing a custom `make_symbols()`.
Eren Gölge 2021-08-07 21:46:10 +00:00
parent bd4e29b4dd
commit 003e5579e8
5 changed files with 134 additions and 36 deletions
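In short: a model can expose a `make_symbols()` method that returns its own symbol list, and the text-processing utilities then build their character-to-ID mapping from that list instead of the default symbol table. A minimal, self-contained sketch of the idea; the `MyModel` class, the `text_to_ids` helper, and the config fields are illustrative stand-ins, not the project's actual classes:

from typing import Dict, List

class MyModel:
    """Toy model that defines its own symbol list via `make_symbols()`."""

    @staticmethod
    def make_symbols(config: Dict) -> List[str]:
        # special tokens first, then the model's character set from the config
        return ["<pad>", "<bos>", "<eos>"] + sorted(set(config["characters"]))

def text_to_ids(text: str, custom_symbols: List[str] = None) -> List[int]:
    # when a model supplies its own symbols, IDs are indices into that list
    symbols = custom_symbols if custom_symbols is not None else list("abcdefghijklmnopqrstuvwxyz ")
    symbol_to_id = {s: i for i, s in enumerate(symbols)}
    return [symbol_to_id[c] for c in text.lower() if c in symbol_to_id]

config = {"characters": "abcdefghijklmnopqrstuvwxyz' "}
custom_symbols = MyModel.make_symbols(config) if hasattr(MyModel, "make_symbols") else None
print(text_to_ids("Be a voice, not an echo.", custom_symbols=custom_symbols))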

View File

@@ -23,7 +23,9 @@ class TTSDataset(Dataset):
         ap: AudioProcessor,
         meta_data: List[List],
         characters: Dict = None,
+        custom_symbols: List = None,
         add_blank: bool = False,
+        return_wav: bool = False,
         batch_group_size: int = 0,
         min_seq_len: int = 0,
         max_seq_len: int = float("inf"),
@@ -54,9 +56,14 @@ class TTSDataset(Dataset):
             characters (dict): `dict` of custom text characters used for converting texts to sequences.
+            custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own
+                set of symbols need to pass it here. Defaults to `None`.
             add_blank (bool): Add a special `blank` character after every other character. It helps some
                 models achieve better results. Defaults to false.
+            return_wav (bool): Return the waveform of the sample. Defaults to False.
             batch_group_size (int): Range of batch randomization after sorting
                 sequences by length. It shuffles each batch with bucketing to gather similar lenght sequences in a
                 batch. Set 0 to disable. Defaults to 0.
@@ -95,10 +102,12 @@ class TTSDataset(Dataset):
         self.sample_rate = ap.sample_rate
         self.cleaners = text_cleaner
         self.compute_linear_spec = compute_linear_spec
+        self.return_wav = return_wav
         self.min_seq_len = min_seq_len
         self.max_seq_len = max_seq_len
         self.ap = ap
         self.characters = characters
+        self.custom_symbols = custom_symbols
         self.add_blank = add_blank
         self.use_phonemes = use_phonemes
         self.phoneme_cache_path = phoneme_cache_path
@@ -109,6 +118,7 @@ class TTSDataset(Dataset):
         self.use_noise_augment = use_noise_augment
         self.verbose = verbose
         self.input_seq_computed = False
+        self.rescue_item_idx = 1
         if use_phonemes and not os.path.isdir(phoneme_cache_path):
             os.makedirs(phoneme_cache_path, exist_ok=True)
         if self.verbose:
@@ -128,13 +138,21 @@ class TTSDataset(Dataset):
         return data

     @staticmethod
-    def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, characters, add_blank):
+    def _generate_and_cache_phoneme_sequence(
+        text, cache_path, cleaners, language, custom_symbols, characters, add_blank
+    ):
         """generate a phoneme sequence from text.
         since the usage is for subsequent caching, we never add bos and
         eos chars here. Instead we add those dynamically later; based on the
         config option."""
         phonemes = phoneme_to_sequence(
-            text, [cleaners], language=language, enable_eos_bos=False, tp=characters, add_blank=add_blank
+            text,
+            [cleaners],
+            language=language,
+            enable_eos_bos=False,
+            custom_symbols=custom_symbols,
+            tp=characters,
+            add_blank=add_blank,
         )
         phonemes = np.asarray(phonemes, dtype=np.int32)
         np.save(cache_path, phonemes)
@@ -142,7 +160,7 @@ class TTSDataset(Dataset):
     @staticmethod
     def _load_or_generate_phoneme_sequence(
-        wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, characters, add_blank
+        wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, custom_symbols, characters, add_blank
     ):
         file_name = os.path.splitext(os.path.basename(wav_file))[0]
@@ -153,12 +171,12 @@ class TTSDataset(Dataset):
             phonemes = np.load(cache_path)
         except FileNotFoundError:
             phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
-                text, cache_path, cleaners, language, characters, add_blank
+                text, cache_path, cleaners, language, custom_symbols, characters, add_blank
             )
         except (ValueError, IOError):
             print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
             phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
-                text, cache_path, cleaners, language, characters, add_blank
+                text, cache_path, cleaners, language, custom_symbols, characters, add_blank
             )
         if enable_eos_bos:
             phonemes = pad_with_eos_bos(phonemes, tp=characters)
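The caching path above stores each phoneme sequence as a `.npy` file keyed by the audio file name and regenerates it when the cache is missing or unreadable. A standalone sketch of that load-or-recompute pattern; the `fake_phonemize` helper and the file paths are illustrative, not the library's API:

import os
import numpy as np

def fake_phonemize(text):
    # stand-in for the real phonemizer: map each character to a small integer
    return [ord(c) % 64 for c in text]

def load_or_generate(text, wav_file, cache_dir):
    # the cache key is derived from the audio file name, as in the dataset code above
    file_name = os.path.splitext(os.path.basename(wav_file))[0]
    cache_path = os.path.join(cache_dir, file_name + "_phoneme.npy")
    try:
        return np.load(cache_path)
    except (FileNotFoundError, ValueError, IOError):
        # missing or unreadable cache: recompute and overwrite
        os.makedirs(cache_dir, exist_ok=True)
        seq = np.asarray(fake_phonemize(text), dtype=np.int32)
        np.save(cache_path, seq)
        return seq

print(load_or_generate("hello world", "clips/sample_0001.wav", "/tmp/phoneme_cache"))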
@@ -189,13 +207,19 @@ class TTSDataset(Dataset):
                 self.enable_eos_bos,
                 self.cleaners,
                 self.phoneme_language,
+                self.custom_symbols,
                 self.characters,
                 self.add_blank,
             )
         else:
             text = np.asarray(
-                text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
+                text_to_sequence(
+                    text,
+                    [self.cleaners],
+                    custom_symbols=self.custom_symbols,
+                    tp=self.characters,
+                    add_blank=self.add_blank,
+                ),
                 dtype=np.int32,
             )
@@ -209,7 +233,7 @@ class TTSDataset(Dataset):
             # return a different sample if the phonemized
             # text is longer than the threshold
             # TODO: find a better fix
-            return self.load_data(100)
+            return self.load_data(self.rescue_item_idx)

         sample = {
             "text": text,
@@ -238,7 +262,13 @@ class TTSDataset(Dataset):
             for idx, item in enumerate(tqdm.tqdm(self.items)):
                 text, *_ = item
                 sequence = np.asarray(
-                    text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
+                    text_to_sequence(
+                        text,
+                        [self.cleaners],
+                        custom_symbols=self.custom_symbols,
+                        tp=self.characters,
+                        add_blank=self.add_blank,
+                    ),
                     dtype=np.int32,
                 )
                 self.items[idx][0] = sequence
@@ -249,6 +279,7 @@ class TTSDataset(Dataset):
                     self.enable_eos_bos,
                     self.cleaners,
                     self.phoneme_language,
+                    self.custom_symbols,
                     self.characters,
                     self.add_blank,
                 ]
@@ -347,6 +378,14 @@ class TTSDataset(Dataset):
             mel_lengths = [m.shape[1] for m in mel]

+            # lengths adjusted by the reduction factor
+            mel_lengths_adjusted = [
+                m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step))
+                if m.shape[1] % self.outputs_per_step
+                else m.shape[1]
+                for m in mel
+            ]
+
             # compute 'stop token' targets
             stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths]
@@ -385,6 +424,20 @@ class TTSDataset(Dataset):
             else:
                 linear = None

+            # format waveforms
+            wav_padded = None
+            if self.return_wav:
+                wav_lengths = [w.shape[0] for w in wav]
+                max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
+                wav_lengths = torch.LongTensor(wav_lengths)
+                wav_padded = torch.zeros(len(batch), 1, max_wav_len)
+                for i, w in enumerate(wav):
+                    mel_length = mel_lengths_adjusted[i]
+                    w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
+                    w = w[: mel_length * self.ap.hop_length]
+                    wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
+                wav_padded.transpose_(1, 2)
+
             # collate attention alignments
             if batch[0]["attn"] is not None:
                 attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing]
@@ -409,6 +462,7 @@ class TTSDataset(Dataset):
                     d_vectors,
                     speaker_ids,
                     attns,
+                    wav_padded,
                 )

         raise TypeError(
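The new collation logic above first rounds each mel length up to a multiple of the reduction factor (`outputs_per_step`), then pads or crops every waveform to exactly `frames * hop_length` samples so the audio batch stays aligned with the spectrogram batch. A numpy-only sketch of that length arithmetic; the `hop_length`, `outputs_per_step`, and sample lengths are made-up values for illustration:

import numpy as np

hop_length = 256       # samples per spectrogram frame (illustrative value)
outputs_per_step = 2   # reduction factor r: the decoder emits r frames per step

def adjust_length(n_frames, r):
    # round the frame count up to the next multiple of the reduction factor
    return n_frames if n_frames % r == 0 else n_frames + (r - n_frames % r)

mel_lengths = [101, 88, 96]
mel_lengths_adjusted = [adjust_length(n, outputs_per_step) for n in mel_lengths]
print(mel_lengths_adjusted)  # [102, 88, 96]

# pad every waveform to max_frames * hop_length so wav and mel batches stay aligned
max_wav_len = max(mel_lengths_adjusted) * hop_length
wavs = [np.random.randn(n * hop_length - 40).astype(np.float32) for n in mel_lengths]
wav_padded = np.zeros((len(wavs), max_wav_len), dtype=np.float32)
for i, w in enumerate(wavs):
    # edge-pad a little past the target, then crop to exactly frames * hop_length
    w = np.pad(w, (0, hop_length * outputs_per_step), mode="edge")
    w = w[: mel_lengths_adjusted[i] * hop_length]
    wav_padded[i, : w.shape[0]] = w
print(wav_padded.shape)  # (3, 26112)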

View File

@@ -15,7 +15,7 @@ if "tensorflow" in installed or "tensorflow-gpu" in installed:
     import tensorflow as tf


-def text_to_seq(text, CONFIG):
+def text_to_seq(text, CONFIG, custom_symbols=None):
     text_cleaner = [CONFIG.text_cleaner]
     # text ot phonemes to sequence vector
     if CONFIG.use_phonemes:
@@ -28,16 +28,14 @@ def text_to_seq(text, CONFIG):
                 tp=CONFIG.characters,
                 add_blank=CONFIG.add_blank,
                 use_espeak_phonemes=CONFIG.use_espeak_phonemes,
+                custom_symbols=custom_symbols,
             ),
             dtype=np.int32,
         )
     else:
         seq = np.asarray(
             text_to_sequence(
-                text,
-                text_cleaner,
-                tp=CONFIG.characters,
-                add_blank=CONFIG.add_blank,
+                text, text_cleaner, tp=CONFIG.characters, add_blank=CONFIG.add_blank, custom_symbols=custom_symbols
             ),
             dtype=np.int32,
         )
@@ -229,13 +227,16 @@ def synthesis(
     """
     # GST processing
     style_mel = None
+    custom_symbols = None
     if CONFIG.has("gst") and CONFIG.gst and style_wav is not None:
         if isinstance(style_wav, dict):
             style_mel = style_wav
         else:
             style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
+    if hasattr(model, "make_symbols"):
+        custom_symbols = model.make_symbols(CONFIG)
     # preprocess the given text
-    text_inputs = text_to_seq(text, CONFIG)
+    text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols)
     # pass tensors to backend
     if backend == "torch":
         if speaker_id is not None:
@@ -274,15 +275,18 @@ def synthesis(
     # convert outputs to numpy
     # plot results
     wav = None
-    if use_griffin_lim:
-        wav = inv_spectrogram(model_outputs, ap, CONFIG)
-        # trim silence
-        if do_trim_silence:
-            wav = trim_silence(wav, ap)
+    if hasattr(model, "END2END") and model.END2END:
+        wav = model_outputs.squeeze(0)
+    else:
+        if use_griffin_lim:
+            wav = inv_spectrogram(model_outputs, ap, CONFIG)
+            # trim silence
+            if do_trim_silence:
+                wav = trim_silence(wav, ap)
     return_dict = {
         "wav": wav,
         "alignments": alignments,
-        "model_outputs": model_outputs,
         "text_inputs": text_inputs,
+        "outputs": outputs,
     }
     return return_dict
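The synthesis path above probes the model for optional capabilities instead of hard-coding model types: `make_symbols()` supplies a custom symbol table, and an `END2END` flag lets a model return a waveform directly instead of going through Griffin-Lim. A self-contained sketch of that duck-typing pattern; the dummy classes and the `synthesize` helper are illustrative, not the library's actual models:

class TacotronLike:
    # spectrogram model: no custom symbols, no end-to-end waveform output
    def infer(self, seq):
        return {"spectrogram": [[0.0] * 80]}

class End2EndModel:
    END2END = True

    @staticmethod
    def make_symbols(config):
        # model-specific symbol table derived from the config
        return list(config["characters"]) + ["<pad>", "<blank>"]

    def infer(self, seq):
        return {"waveform": [0.0] * 22050}

def synthesize(model, text, config):
    custom_symbols = None
    if hasattr(model, "make_symbols"):
        custom_symbols = model.make_symbols(config)
    symbols = custom_symbols or list(config["characters"])
    seq = [symbols.index(c) for c in text if c in symbols]
    outputs = model.infer(seq)
    if getattr(model, "END2END", False):
        return outputs["waveform"]      # model already produced audio
    return outputs["spectrogram"]       # still needs a vocoder / Griffin-Lim step

config = {"characters": "abcdefghijklmnopqrstuvwxyz "}
print(len(synthesize(End2EndModel(), "hello world", config)))   # 22050 samples
print(len(synthesize(TacotronLike(), "hello world", config)))   # 1 spectrogram frame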

View File

@@ -2,10 +2,9 @@
 # adapted from https://github.com/keithito/tacotron

 import re
-import unicodedata
+from typing import Dict, List

 import gruut
-from packaging import version

 from TTS.tts.utils.text import cleaners
 from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes
@@ -22,6 +21,7 @@ _id_to_phonemes = {i: s for i, s in enumerate(phonemes)}
 _symbols = symbols
 _phonemes = phonemes
+
 # Regular expression matching text enclosed in curly braces:
 _CURLY_RE = re.compile(r"(.*?)\{(.+?)\}(.*)")
@@ -81,7 +81,7 @@ def text2phone(text, language, use_espeak_phonemes=False):

         # Fix a few phonemes
         ph = ph.translate(GRUUT_TRANS_TABLE)
-        print(" > Phonemes: {}".format(ph))
+        # print(" > Phonemes: {}".format(ph))
         return ph

     raise ValueError(f" [!] Language {language} is not supported for phonemization.")
@@ -106,13 +106,38 @@ def pad_with_eos_bos(phoneme_sequence, tp=None):

 def phoneme_to_sequence(
-    text, cleaner_names, language, enable_eos_bos=False, tp=None, add_blank=False, use_espeak_phonemes=False
-):
+    text: str,
+    cleaner_names: List[str],
+    language: str,
+    enable_eos_bos: bool = False,
+    custom_symbols: List[str] = None,
+    tp: Dict = None,
+    add_blank: bool = False,
+    use_espeak_phonemes: bool = False,
+) -> List[int]:
+    """Converts a string of phonemes to a sequence of IDs.
+
+    Args:
+        text (str): string to convert to a sequence
+        cleaner_names (List[str]): names of the cleaner functions to run the text through
+        language (str): text language key for phonemization.
+        enable_eos_bos (bool): whether to append the end-of-sentence and beginning-of-sentence tokens.
+        tp (Dict): dictionary of character parameters to use a custom character set.
+        add_blank (bool): option to add a blank token between each token.
+        use_espeak_phonemes (bool): use espeak based lexicons to convert phonemes to sequenc
+
+    Returns:
+        List[int]: List of integers corresponding to the symbols in the text
+    """
     # pylint: disable=global-statement
     global _phonemes_to_id, _phonemes
     if tp:
         _, _phonemes = make_symbols(**tp)
         _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
+    elif custom_symbols is not None:
+        _phonemes = custom_symbols
+        _phonemes_to_id = {s: i for i, s in enumerate(custom_symbols)}

     sequence = []
     clean_text = _clean_text(text, cleaner_names)
@@ -127,7 +152,6 @@ def phoneme_to_sequence(
         sequence = pad_with_eos_bos(sequence, tp=tp)
     if add_blank:
         sequence = intersperse(sequence, len(_phonemes))  # add a blank token (new), whose id number is len(_phonemes)
-
     return sequence
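When `add_blank` is set, the ID sequence is interleaved with a blank token whose ID is one past the last symbol ID. A minimal sketch of that interspersing step; the `intersperse` helper here is a plain re-implementation written for illustration, not the library's function:

def intersperse(sequence, token):
    # place `token` before, between, and after every item: [a, b] -> [t, a, t, b, t]
    result = [token] * (2 * len(sequence) + 1)
    result[1::2] = sequence
    return result

phoneme_ids = [12, 7, 33]
num_phonemes = 130                      # size of the symbol table (illustrative)
print(intersperse(phoneme_ids, num_phonemes))
# [130, 12, 130, 7, 130, 33, 130]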
@@ -149,27 +173,31 @@ def sequence_to_phoneme(sequence, tp=None, add_blank=False):
     return result.replace("}{", " ")


-def text_to_sequence(text, cleaner_names, tp=None, add_blank=False):
+def text_to_sequence(
+    text: str, cleaner_names: List[str], custom_symbols: List[str] = None, tp: Dict = None, add_blank: bool = False
+) -> List[int]:
     """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
-    The text can optionally have ARPAbet sequences enclosed in curly braces embedded
-    in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."

     Args:
-        text: string to convert to a sequence
-        cleaner_names: names of the cleaner functions to run the text through
-        tp: dictionary of character parameters to use a custom character set.
+        text (str): string to convert to a sequence
+        cleaner_names (List[str]): names of the cleaner functions to run the text through
+        tp (Dict): dictionary of character parameters to use a custom character set.
+        add_blank (bool): option to add a blank token between each token.

     Returns:
-        List of integers corresponding to the symbols in the text
+        List[int]: List of integers corresponding to the symbols in the text
     """
     # pylint: disable=global-statement
     global _symbol_to_id, _symbols
     if tp:
         _symbols, _ = make_symbols(**tp)
         _symbol_to_id = {s: i for i, s in enumerate(_symbols)}
+    elif custom_symbols is not None:
+        _symbols = custom_symbols
+        _symbol_to_id = {s: i for i, s in enumerate(custom_symbols)}

     sequence = []
     # Check for curly braces and treat their contents as ARPAbet:
     while text:
         m = _CURLY_RE.match(text)
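Both `text_to_sequence` and `phoneme_to_sequence` now rebuild their symbol-to-ID table from `custom_symbols` when it is given, and fall back to the default table otherwise. A standalone sketch of that lookup logic; the default symbol list and helper names are invented for the example:

from typing import Dict, List, Optional

DEFAULT_SYMBOLS = list("_-!'(),.:;? abcdefghijklmnopqrstuvwxyz")  # illustrative default table

def build_symbol_table(custom_symbols: Optional[List[str]] = None) -> Dict[str, int]:
    # a model-supplied list overrides the default mapping; IDs follow list order
    symbols = custom_symbols if custom_symbols is not None else DEFAULT_SYMBOLS
    return {s: i for i, s in enumerate(symbols)}

def to_sequence(text: str, custom_symbols: Optional[List[str]] = None) -> List[int]:
    table = build_symbol_table(custom_symbols)
    # unknown characters are skipped, mirroring a lookup-and-drop policy
    return [table[c] for c in text.lower() if c in table]

print(to_sequence("Hi!"))                               # IDs from the default table
print(to_sequence("Hi!", custom_symbols=list("hi! ")))  # IDs from the model's own table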

View File

@@ -42,6 +42,7 @@ class TestTTSDataset(unittest.TestCase):
             r,
             c.text_cleaner,
             compute_linear_spec=True,
+            return_wav=True,
             ap=self.ap,
             meta_data=items,
             characters=c.characters,
@@ -75,16 +76,26 @@ class TestTTSDataset(unittest.TestCase):
                 mel_lengths = data[5]
                 stop_target = data[6]
                 item_idx = data[7]
+                wavs = data[11]

                 neg_values = text_input[text_input < 0]
                 check_count = len(neg_values)
                 assert check_count == 0, " !! Negative values in text_input: {}".format(check_count)
+                # TODO: more assertion here
                 assert isinstance(speaker_name[0], str)
                 assert linear_input.shape[0] == c.batch_size
                 assert linear_input.shape[2] == self.ap.fft_size // 2 + 1
                 assert mel_input.shape[0] == c.batch_size
                 assert mel_input.shape[2] == c.audio["num_mels"]
+                assert (
+                    wavs.shape[1] == mel_input.shape[1] * c.audio.hop_length
+                ), f"wavs.shape: {wavs.shape[1]}, mel_input.shape: {mel_input.shape[1] * c.audio.hop_length}"
+                # make sure that the computed mels and the waveform match and correctly computed
+                mel_new = self.ap.melspectrogram(wavs[0].squeeze().numpy())
+                ignore_seg = -(1 + c.audio.win_length // c.audio.hop_length)
+                mel_diff = (mel_new[:, : mel_input.shape[1]] - mel_input[0].T.numpy())[:, 0:ignore_seg]
+                assert abs(mel_diff.sum()) < 1e-5

                 # check normalization ranges
                 if self.ap.symmetric_norm:
                     assert mel_input.max() <= self.ap.max_norm
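The mel/waveform consistency check above ignores the last few frames because the STFT analysis window runs past the end of the cropped waveform there. A small arithmetic sketch of how many trailing frames get dropped; the audio settings are example values, not the test config:

win_length = 1024   # STFT window size in samples (example value)
hop_length = 256    # samples between consecutive frames (example value)

# frames whose analysis window extends past the end of the cropped waveform
ignore_seg = -(1 + win_length // hop_length)
print(ignore_seg)   # -5  -> compare all but the last 5 mel frames

mel_frames = 96
compared_frames = mel_frames + ignore_seg   # slicing [:, 0:ignore_seg] keeps this many
print(compared_frames)                      # 91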

View File

@@ -27,6 +27,7 @@ config = AlignTTSConfig(
         "Be a voice, not an echo.",
     ],
 )
+
 config.audio.do_trim_silence = True
 config.audio.trim_db = 60
 config.save_json(config_path)