Refactor TTSDataset to use TTSTokenizer

Eren Gölge 2021-11-16 13:33:21 +01:00
parent 2480bbe937
commit e4049aa31a
1 changed file with 28 additions and 85 deletions


@@ -10,7 +10,7 @@ import tqdm
 from torch.utils.data import Dataset
 from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
-from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
+from TTS.tts.utils.text import TTSTokenizer
 from TTS.utils.audio import AudioProcessor
@@ -18,23 +18,17 @@ class TTSDataset(Dataset):
     def __init__(
         self,
         outputs_per_step: int,
-        text_cleaner: list,
         compute_linear_spec: bool,
         ap: AudioProcessor,
         meta_data: List[Dict],
+        tokenizer: "TTSTokenizer" = None,
         compute_f0: bool = False,
         f0_cache_path: str = None,
-        characters: Dict = None,
-        custom_symbols: List = None,
-        add_blank: bool = False,
         return_wav: bool = False,
         batch_group_size: int = 0,
         min_seq_len: int = 0,
         max_seq_len: int = float("inf"),
-        use_phonemes: bool = False,
         phoneme_cache_path: str = None,
-        phoneme_language: str = "en-us",
-        enable_eos_bos: bool = False,
         speaker_id_mapping: Dict = None,
         d_vector_mapping: Dict = None,
         language_id_mapping: Dict = None,
@@ -48,26 +42,19 @@ class TTSDataset(Dataset):
         Args:
             outputs_per_step (int): Number of time frames predicted per step.
-            text_cleaner (list): List of text cleaners to clean the input text before converting to sequence IDs.
             compute_linear_spec (bool): compute linear spectrogram if True.
             ap (TTS.tts.utils.AudioProcessor): Audio processor object.
             meta_data (list): List of dataset samples.
+            tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else
+                use the given. Defaults to None.
             compute_f0 (bool): compute f0 if True. Defaults to False.
             f0_cache_path (str): Path to store f0 cache. Defaults to None.
-            characters (dict): `dict` of custom text characters used for converting texts to sequences.
-            custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own
-                set of symbols need to pass it here. Defaults to `None`.
-            add_blank (bool): Add a special `blank` character after every other character. It helps some
-                models achieve better results. Defaults to false.
             return_wav (bool): Return the waveform of the sample. Defaults to False.
             batch_group_size (int): Range of batch randomization after sorting
@@ -82,16 +69,9 @@ class TTSDataset(Dataset):
                 It helps for controlling the VRAM usage against long input sequences. Especially models with
                 RNN layers are sensitive to input length. Defaults to `Inf`.
-            use_phonemes (bool): If true, input text converted to phonemes. Defaults to false.
             phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a
                 separate file. Defaults to None.
-            phoneme_language (str): One the languages from supported by the phonemizer interface. Defaults to `en-us`.
-            enable_eos_bos (bool): Enable the `end of sentence` and the `beginning of sentences characters`. Defaults
-                to False.
             speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the
                 embedding layer. Defaults to None.
@@ -106,7 +86,6 @@ class TTSDataset(Dataset):
         self.items = meta_data
         self.outputs_per_step = outputs_per_step
         self.sample_rate = ap.sample_rate
-        self.cleaners = text_cleaner
         self.compute_linear_spec = compute_linear_spec
         self.return_wav = return_wav
         self.compute_f0 = compute_f0
@@ -114,13 +93,7 @@ class TTSDataset(Dataset):
         self.min_seq_len = min_seq_len
         self.max_seq_len = max_seq_len
         self.ap = ap
-        self.characters = characters
-        self.custom_symbols = custom_symbols
-        self.add_blank = add_blank
-        self.use_phonemes = use_phonemes
         self.phoneme_cache_path = phoneme_cache_path
-        self.phoneme_language = phoneme_language
-        self.enable_eos_bos = enable_eos_bos
         self.speaker_id_mapping = speaker_id_mapping
         self.d_vector_mapping = d_vector_mapping
         self.language_id_mapping = language_id_mapping
@@ -130,17 +103,23 @@ class TTSDataset(Dataset):
         self.input_seq_computed = False
         self.rescue_item_idx = 1
         self.pitch_computed = False
+        self.tokenizer = tokenizer

-        if use_phonemes and not os.path.isdir(phoneme_cache_path):
+        if self.tokenizer.use_phonemes and not os.path.isdir(phoneme_cache_path):
             os.makedirs(phoneme_cache_path, exist_ok=True)
         if compute_f0:
             self.pitch_extractor = PitchExtractor(self.items, verbose=verbose)
         if self.verbose:
-            print("\n > DataLoader initialization")
-            print(" | > Use phonemes: {}".format(self.use_phonemes))
-            if use_phonemes:
-                print(" | > phoneme language: {}".format(phoneme_language))
-            print(" | > Number of instances : {}".format(len(self.items)))
+            self.print_logs()
+
+    def print_logs(self, level: int = 0) -> None:
+        indent = "\t" * level
+        print("\n")
+        print(f"{indent}> DataLoader initialization")
+        print(f"{indent}| > Tokenizer:")
+        self.tokenizer.print_logs(level + 1)
+        print(f"{indent}| > Number of instances : {len(self.items)}")

     def load_wav(self, filename):
         audio = self.ap.load_wav(filename)
@@ -152,48 +131,30 @@ class TTSDataset(Dataset):
         return data

     @staticmethod
-    def _generate_and_cache_phoneme_sequence(
-        text, cache_path, cleaners, language, custom_symbols, characters, add_blank
-    ):
+    def _generate_and_cache_phoneme_sequence(text, tokenizer, cache_path):
         """generate a phoneme sequence from text.
         since the usage is for subsequent caching, we never add bos and
         eos chars here. Instead we add those dynamically later; based on the
         config option."""
-        phonemes = phoneme_to_sequence(
-            text,
-            [cleaners],
-            language=language,
-            enable_eos_bos=False,
-            custom_symbols=custom_symbols,
-            tp=characters,
-            add_blank=add_blank,
-        )
+        phonemes = tokenizer.text_to_ids(text)
         phonemes = np.asarray(phonemes, dtype=np.int32)
         np.save(cache_path, phonemes)
         return phonemes

     @staticmethod
-    def _load_or_generate_phoneme_sequence(
-        wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, custom_symbols, characters, add_blank
-    ):
+    def _load_or_generate_phoneme_sequence(wav_file, text, language, tokenizer, phoneme_cache_path):
         file_name = os.path.splitext(os.path.basename(wav_file))[0]

         # different names for normal phonemes and with blank chars.
-        file_name_ext = "_blanked_phoneme.npy" if add_blank else "_phoneme.npy"
+        file_name_ext = "_phoneme.npy"
         cache_path = os.path.join(phoneme_cache_path, file_name + file_name_ext)
         try:
             phonemes = np.load(cache_path)
         except FileNotFoundError:
-            phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
-                text, cache_path, cleaners, language, custom_symbols, characters, add_blank
-            )
+            phonemes = TTSDataset._generate_and_cache_phoneme_sequence(text, tokenizer, cache_path)
         except (ValueError, IOError):
             print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
-            phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
-                text, cache_path, cleaners, language, custom_symbols, characters, add_blank
-            )
-        if enable_eos_bos:
-            phonemes = pad_with_eos_bos(phonemes, tp=characters)
+            phonemes = TTSDataset._generate_and_cache_phoneme_sequence(text, tokenizer, cache_path)
         phonemes = np.asarray(phonemes, dtype=np.int32)
         return phonemes
@@ -208,27 +169,17 @@ class TTSDataset(Dataset):
            wav = wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)

        if not self.input_seq_computed:
-            if self.use_phonemes:
+            if self.tokenizer.use_phonemes:
                text = self._load_or_generate_phoneme_sequence(
                    item["audio_file"],
                    item["text"],
-                    self.phoneme_cache_path,
-                    self.enable_eos_bos,
-                    self.cleaners,
                    item["language"] if item["language"] else self.phoneme_language,
-                    self.custom_symbols,
-                    self.characters,
-                    self.add_blank,
+                    self.tokenizer,
+                    self.phoneme_cache_path,
                )
            else:
                text = np.asarray(
-                    text_to_sequence(
-                        item["text"],
-                        [self.cleaners],
-                        custom_symbols=self.custom_symbols,
-                        tp=self.characters,
-                        add_blank=self.add_blank,
-                    ),
+                    self.tokenizer.text_to_ids(item["text"], item["language"]),
                    dtype=np.int32,
                )
@@ -281,24 +232,16 @@ class TTSDataset(Dataset):
                print(" | > Computing input sequences ...")
            for idx, item in enumerate(tqdm.tqdm(self.items)):
                sequence = np.asarray(
-                    text_to_sequence(
-                        item["text"],
-                        [self.cleaners],
-                        custom_symbols=self.custom_symbols,
-                        tp=self.characters,
-                        add_blank=self.add_blank,
-                    ),
+                    self.tokenizer.text_to_ids(item["text"]),
                    dtype=np.int32,
                )
                self.items[idx][0] = sequence

        else:
            func_args = [
                self.phoneme_cache_path,
                self.enable_eos_bos,
                self.cleaners,
                self.phoneme_language,
-                self.custom_symbols,
                self.characters,
                self.add_blank,
            ]
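
Taken together, the refactor moves all text-processing state (cleaners, character set, blank padding, EOS/BOS, phoneme settings) out of TTSDataset and into the tokenizer. Below is a minimal sketch of how construction might look after this commit. The TTSDataset keyword arguments and the sample-dict keys ("text", "audio_file", "language") come from the diff above; the TTSTokenizer keyword arguments, the cleaner import, and the TTSDataset import path are assumptions, since the tokenizer's constructor is not part of this diff.

from TTS.tts.utils.text import TTSTokenizer                 # new import introduced by this commit
from TTS.tts.utils.text.cleaners import basic_cleaners      # assumed location of the cleaner functions
from TTS.utils.audio import AudioProcessor
from TTS.tts.datasets import TTSDataset                     # assumed import path for the dataset class

# Character-based tokenizer; the keyword names below are assumptions about the
# TTSTokenizer constructor and stand in for the arguments removed from the dataset.
tokenizer = TTSTokenizer(
    use_phonemes=False,           # replaces the removed `use_phonemes` dataset argument
    text_cleaner=basic_cleaners,  # replaces the removed `text_cleaner` dataset argument
    add_blank=False,              # replaces the removed `add_blank` dataset argument
    use_eos_bos=False,            # replaces the removed `enable_eos_bos` dataset argument
)

ap = AudioProcessor(sample_rate=22050)  # remaining audio settings omitted for brevity

# Sample dicts use the keys the dataset reads in the diff: "text", "audio_file", "language".
train_samples = [{"text": "Hello world.", "audio_file": "wavs/0001.wav", "language": "en-us"}]

dataset = TTSDataset(
    outputs_per_step=1,
    compute_linear_spec=False,
    ap=ap,
    meta_data=train_samples,
    tokenizer=tokenizer,  # new argument: all text-to-ID conversion is delegated here
)

# The dataset now calls tokenizer.text_to_ids() internally instead of
# text_to_sequence() / phoneme_to_sequence():
ids = tokenizer.text_to_ids("Hello world.")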