Refactor TTSDataset to use TTSTokenizer

This commit is contained in:
Eren Gölge 2021-11-16 13:33:21 +01:00
parent 2480bbe937
commit e4049aa31a
1 changed files with 28 additions and 85 deletions

View File

@ -10,7 +10,7 @@ import tqdm
from torch.utils.data import Dataset
from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
from TTS.tts.utils.text import TTSTokenizer
from TTS.utils.audio import AudioProcessor
@ -18,23 +18,17 @@ class TTSDataset(Dataset):
def __init__(
self,
outputs_per_step: int,
text_cleaner: list,
compute_linear_spec: bool,
ap: AudioProcessor,
meta_data: List[Dict],
tokenizer: "TTSTokenizer" = None,
compute_f0: bool = False,
f0_cache_path: str = None,
characters: Dict = None,
custom_symbols: List = None,
add_blank: bool = False,
return_wav: bool = False,
batch_group_size: int = 0,
min_seq_len: int = 0,
max_seq_len: int = float("inf"),
use_phonemes: bool = False,
phoneme_cache_path: str = None,
phoneme_language: str = "en-us",
enable_eos_bos: bool = False,
speaker_id_mapping: Dict = None,
d_vector_mapping: Dict = None,
language_id_mapping: Dict = None,
@ -48,26 +42,19 @@ class TTSDataset(Dataset):
Args:
outputs_per_step (int): Number of time frames predicted per step.
text_cleaner (list): List of text cleaners to clean the input text before converting to sequence IDs.
compute_linear_spec (bool): compute linear spectrogram if True.
ap (TTS.tts.utils.AudioProcessor): Audio processor object.
meta_data (list): List of dataset samples.
tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else
use the given. Defaults to None.
compute_f0 (bool): compute f0 if True. Defaults to False.
f0_cache_path (str): Path to store f0 cache. Defaults to None.
characters (dict): `dict` of custom text characters used for converting texts to sequences.
custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own
set of symbols need to pass it here. Defaults to `None`.
add_blank (bool): Add a special `blank` character after every other character. It helps some
models achieve better results. Defaults to false.
return_wav (bool): Return the waveform of the sample. Defaults to False.
batch_group_size (int): Range of batch randomization after sorting
@ -82,16 +69,9 @@ class TTSDataset(Dataset):
It helps for controlling the VRAM usage against long input sequences. Especially models with
RNN layers are sensitive to input length. Defaults to `Inf`.
use_phonemes (bool): If true, input text converted to phonemes. Defaults to false.
phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a
separate file. Defaults to None.
phoneme_language (str): One the languages from supported by the phonemizer interface. Defaults to `en-us`.
enable_eos_bos (bool): Enable the `end of sentence` and the `beginning of sentences characters`. Defaults
to False.
speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the
embedding layer. Defaults to None.
@ -106,7 +86,6 @@ class TTSDataset(Dataset):
self.items = meta_data
self.outputs_per_step = outputs_per_step
self.sample_rate = ap.sample_rate
self.cleaners = text_cleaner
self.compute_linear_spec = compute_linear_spec
self.return_wav = return_wav
self.compute_f0 = compute_f0
@ -114,13 +93,7 @@ class TTSDataset(Dataset):
self.min_seq_len = min_seq_len
self.max_seq_len = max_seq_len
self.ap = ap
self.characters = characters
self.custom_symbols = custom_symbols
self.add_blank = add_blank
self.use_phonemes = use_phonemes
self.phoneme_cache_path = phoneme_cache_path
self.phoneme_language = phoneme_language
self.enable_eos_bos = enable_eos_bos
self.speaker_id_mapping = speaker_id_mapping
self.d_vector_mapping = d_vector_mapping
self.language_id_mapping = language_id_mapping
@ -130,17 +103,23 @@ class TTSDataset(Dataset):
self.input_seq_computed = False
self.rescue_item_idx = 1
self.pitch_computed = False
self.tokenizer = tokenizer
if use_phonemes and not os.path.isdir(phoneme_cache_path):
if self.tokenizer.use_phonemes and not os.path.isdir(phoneme_cache_path):
os.makedirs(phoneme_cache_path, exist_ok=True)
if compute_f0:
self.pitch_extractor = PitchExtractor(self.items, verbose=verbose)
if self.verbose:
print("\n > DataLoader initialization")
print(" | > Use phonemes: {}".format(self.use_phonemes))
if use_phonemes:
print(" | > phoneme language: {}".format(phoneme_language))
print(" | > Number of instances : {}".format(len(self.items)))
self.print_logs()
def print_logs(self, level: int = 0) -> None:
indent = "\t" * level
print("\n")
print(f"{indent}> DataLoader initialization")
print(f"{indent}| > Tokenizer:")
self.tokenizer.print_logs(level + 1)
print(f"{indent}| > Number of instances : {len(self.items)}")
def load_wav(self, filename):
audio = self.ap.load_wav(filename)
@ -152,48 +131,30 @@ class TTSDataset(Dataset):
return data
@staticmethod
def _generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
):
def _generate_and_cache_phoneme_sequence(text, tokenizer, cache_path):
"""generate a phoneme sequence from text.
since the usage is for subsequent caching, we never add bos and
eos chars here. Instead we add those dynamically later; based on the
config option."""
phonemes = phoneme_to_sequence(
text,
[cleaners],
language=language,
enable_eos_bos=False,
custom_symbols=custom_symbols,
tp=characters,
add_blank=add_blank,
)
phonemes = tokenizer.text_to_ids(text)
phonemes = np.asarray(phonemes, dtype=np.int32)
np.save(cache_path, phonemes)
return phonemes
@staticmethod
def _load_or_generate_phoneme_sequence(
wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, custom_symbols, characters, add_blank
):
def _load_or_generate_phoneme_sequence(wav_file, text, language, tokenizer, phoneme_cache_path):
file_name = os.path.splitext(os.path.basename(wav_file))[0]
# different names for normal phonemes and with blank chars.
file_name_ext = "_blanked_phoneme.npy" if add_blank else "_phoneme.npy"
file_name_ext = "_phoneme.npy"
cache_path = os.path.join(phoneme_cache_path, file_name + file_name_ext)
try:
phonemes = np.load(cache_path)
except FileNotFoundError:
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
)
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(text, tokenizer, cache_path)
except (ValueError, IOError):
print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
)
if enable_eos_bos:
phonemes = pad_with_eos_bos(phonemes, tp=characters)
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(text, tokenizer, cache_path)
phonemes = np.asarray(phonemes, dtype=np.int32)
return phonemes
@ -208,27 +169,17 @@ class TTSDataset(Dataset):
wav = wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)
if not self.input_seq_computed:
if self.use_phonemes:
if self.tokenizer.use_phonemes:
text = self._load_or_generate_phoneme_sequence(
item["audio_file"],
item["text"],
self.phoneme_cache_path,
self.enable_eos_bos,
self.cleaners,
item["language"] if item["language"] else self.phoneme_language,
self.custom_symbols,
self.characters,
self.add_blank,
self.tokenizer,
self.phoneme_cache_path,
)
else:
text = np.asarray(
text_to_sequence(
item["text"],
[self.cleaners],
custom_symbols=self.custom_symbols,
tp=self.characters,
add_blank=self.add_blank,
),
self.tokenizer.text_to_ids(item["text"], item["language"]),
dtype=np.int32,
)
@ -281,24 +232,16 @@ class TTSDataset(Dataset):
print(" | > Computing input sequences ...")
for idx, item in enumerate(tqdm.tqdm(self.items)):
sequence = np.asarray(
text_to_sequence(
item["text"],
[self.cleaners],
custom_symbols=self.custom_symbols,
tp=self.characters,
add_blank=self.add_blank,
),
self.tokenizer.text_to_ids(item["text"]),
dtype=np.int32,
)
self.items[idx][0] = sequence
else:
func_args = [
self.phoneme_cache_path,
self.enable_eos_bos,
self.cleaners,
self.phoneme_language,
self.custom_symbols,
self.characters,
self.add_blank,
]