mirror of https://github.com/coqui-ai/TTS.git
Refactor TTSDataset to use TTSTokenizer
This commit is contained in:
parent
2480bbe937
commit
e4049aa31a
|
@ -10,7 +10,7 @@ import tqdm
|
|||
from torch.utils.data import Dataset
|
||||
|
||||
from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
|
||||
from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
|
||||
from TTS.tts.utils.text import TTSTokenizer
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
||||
|
@ -18,23 +18,17 @@ class TTSDataset(Dataset):
|
|||
def __init__(
|
||||
self,
|
||||
outputs_per_step: int,
|
||||
text_cleaner: list,
|
||||
compute_linear_spec: bool,
|
||||
ap: AudioProcessor,
|
||||
meta_data: List[Dict],
|
||||
tokenizer: "TTSTokenizer" = None,
|
||||
compute_f0: bool = False,
|
||||
f0_cache_path: str = None,
|
||||
characters: Dict = None,
|
||||
custom_symbols: List = None,
|
||||
add_blank: bool = False,
|
||||
return_wav: bool = False,
|
||||
batch_group_size: int = 0,
|
||||
min_seq_len: int = 0,
|
||||
max_seq_len: int = float("inf"),
|
||||
use_phonemes: bool = False,
|
||||
phoneme_cache_path: str = None,
|
||||
phoneme_language: str = "en-us",
|
||||
enable_eos_bos: bool = False,
|
||||
speaker_id_mapping: Dict = None,
|
||||
d_vector_mapping: Dict = None,
|
||||
language_id_mapping: Dict = None,
|
||||
|
@ -48,26 +42,19 @@ class TTSDataset(Dataset):
|
|||
Args:
|
||||
outputs_per_step (int): Number of time frames predicted per step.
|
||||
|
||||
text_cleaner (list): List of text cleaners to clean the input text before converting to sequence IDs.
|
||||
|
||||
compute_linear_spec (bool): compute linear spectrogram if True.
|
||||
|
||||
ap (TTS.tts.utils.AudioProcessor): Audio processor object.
|
||||
|
||||
meta_data (list): List of dataset samples.
|
||||
|
||||
tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else
|
||||
use the given. Defaults to None.
|
||||
|
||||
compute_f0 (bool): compute f0 if True. Defaults to False.
|
||||
|
||||
f0_cache_path (str): Path to store f0 cache. Defaults to None.
|
||||
|
||||
characters (dict): `dict` of custom text characters used for converting texts to sequences.
|
||||
|
||||
custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using their own
|
||||
set of symbols need to pass it here. Defaults to `None`.
|
||||
|
||||
add_blank (bool): Add a special `blank` character after every other character. It helps some
|
||||
models achieve better results. Defaults to false.
|
||||
|
||||
return_wav (bool): Return the waveform of the sample. Defaults to False.
|
||||
|
||||
batch_group_size (int): Range of batch randomization after sorting
|
||||
|
@ -82,16 +69,9 @@ class TTSDataset(Dataset):
|
|||
It helps for controlling the VRAM usage against long input sequences. Especially models with
|
||||
RNN layers are sensitive to input length. Defaults to `Inf`.
|
||||
|
||||
use_phonemes (bool): If true, input text converted to phonemes. Defaults to false.
|
||||
|
||||
phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a
|
||||
separate file. Defaults to None.
|
||||
|
||||
phoneme_language (str): One of the languages supported by the phonemizer interface. Defaults to `en-us`.
|
||||
|
||||
enable_eos_bos (bool): Enable the `end of sentence` and the `beginning of sentence` characters. Defaults
|
||||
to False.
|
||||
|
||||
speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the
|
||||
embedding layer. Defaults to None.
|
||||
|
||||
|
@ -106,7 +86,6 @@ class TTSDataset(Dataset):
|
|||
self.items = meta_data
|
||||
self.outputs_per_step = outputs_per_step
|
||||
self.sample_rate = ap.sample_rate
|
||||
self.cleaners = text_cleaner
|
||||
self.compute_linear_spec = compute_linear_spec
|
||||
self.return_wav = return_wav
|
||||
self.compute_f0 = compute_f0
|
||||
|
@ -114,13 +93,7 @@ class TTSDataset(Dataset):
|
|||
self.min_seq_len = min_seq_len
|
||||
self.max_seq_len = max_seq_len
|
||||
self.ap = ap
|
||||
self.characters = characters
|
||||
self.custom_symbols = custom_symbols
|
||||
self.add_blank = add_blank
|
||||
self.use_phonemes = use_phonemes
|
||||
self.phoneme_cache_path = phoneme_cache_path
|
||||
self.phoneme_language = phoneme_language
|
||||
self.enable_eos_bos = enable_eos_bos
|
||||
self.speaker_id_mapping = speaker_id_mapping
|
||||
self.d_vector_mapping = d_vector_mapping
|
||||
self.language_id_mapping = language_id_mapping
|
||||
|
@ -130,17 +103,23 @@ class TTSDataset(Dataset):
|
|||
self.input_seq_computed = False
|
||||
self.rescue_item_idx = 1
|
||||
self.pitch_computed = False
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
if use_phonemes and not os.path.isdir(phoneme_cache_path):
|
||||
if self.tokenizer.use_phonemes and not os.path.isdir(phoneme_cache_path):
|
||||
os.makedirs(phoneme_cache_path, exist_ok=True)
|
||||
if compute_f0:
|
||||
self.pitch_extractor = PitchExtractor(self.items, verbose=verbose)
|
||||
|
||||
if self.verbose:
|
||||
print("\n > DataLoader initialization")
|
||||
print(" | > Use phonemes: {}".format(self.use_phonemes))
|
||||
if use_phonemes:
|
||||
print(" | > phoneme language: {}".format(phoneme_language))
|
||||
print(" | > Number of instances : {}".format(len(self.items)))
|
||||
self.print_logs()
|
||||
|
||||
def print_logs(self, level: int = 0) -> None:
|
||||
indent = "\t" * level
|
||||
print("\n")
|
||||
print(f"{indent}> DataLoader initialization")
|
||||
print(f"{indent}| > Tokenizer:")
|
||||
self.tokenizer.print_logs(level + 1)
|
||||
print(f"{indent}| > Number of instances : {len(self.items)}")
|
||||
|
||||
def load_wav(self, filename):
|
||||
audio = self.ap.load_wav(filename)
|
||||
|
@ -152,48 +131,30 @@ class TTSDataset(Dataset):
|
|||
return data
|
||||
|
||||
@staticmethod
|
||||
def _generate_and_cache_phoneme_sequence(
|
||||
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
|
||||
):
|
||||
def _generate_and_cache_phoneme_sequence(text, tokenizer, cache_path):
|
||||
"""generate a phoneme sequence from text.
|
||||
since the usage is for subsequent caching, we never add bos and
|
||||
eos chars here. Instead we add those dynamically later; based on the
|
||||
config option."""
|
||||
phonemes = phoneme_to_sequence(
|
||||
text,
|
||||
[cleaners],
|
||||
language=language,
|
||||
enable_eos_bos=False,
|
||||
custom_symbols=custom_symbols,
|
||||
tp=characters,
|
||||
add_blank=add_blank,
|
||||
)
|
||||
phonemes = tokenizer.text_to_ids(text)
|
||||
phonemes = np.asarray(phonemes, dtype=np.int32)
|
||||
np.save(cache_path, phonemes)
|
||||
return phonemes
|
||||
|
||||
@staticmethod
|
||||
def _load_or_generate_phoneme_sequence(
|
||||
wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, custom_symbols, characters, add_blank
|
||||
):
|
||||
def _load_or_generate_phoneme_sequence(wav_file, text, language, tokenizer, phoneme_cache_path):
|
||||
file_name = os.path.splitext(os.path.basename(wav_file))[0]
|
||||
|
||||
# different names for normal phonemes and with blank chars.
|
||||
file_name_ext = "_blanked_phoneme.npy" if add_blank else "_phoneme.npy"
|
||||
file_name_ext = "_phoneme.npy"
|
||||
cache_path = os.path.join(phoneme_cache_path, file_name + file_name_ext)
|
||||
try:
|
||||
phonemes = np.load(cache_path)
|
||||
except FileNotFoundError:
|
||||
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
|
||||
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
|
||||
)
|
||||
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(text, tokenizer, cache_path)
|
||||
except (ValueError, IOError):
|
||||
print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
|
||||
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
|
||||
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
|
||||
)
|
||||
if enable_eos_bos:
|
||||
phonemes = pad_with_eos_bos(phonemes, tp=characters)
|
||||
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(text, tokenizer, cache_path)
|
||||
phonemes = np.asarray(phonemes, dtype=np.int32)
|
||||
return phonemes
|
||||
|
||||
|
@ -208,27 +169,17 @@ class TTSDataset(Dataset):
|
|||
wav = wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)
|
||||
|
||||
if not self.input_seq_computed:
|
||||
if self.use_phonemes:
|
||||
if self.tokenizer.use_phonemes:
|
||||
text = self._load_or_generate_phoneme_sequence(
|
||||
item["audio_file"],
|
||||
item["text"],
|
||||
self.phoneme_cache_path,
|
||||
self.enable_eos_bos,
|
||||
self.cleaners,
|
||||
item["language"] if item["language"] else self.phoneme_language,
|
||||
self.custom_symbols,
|
||||
self.characters,
|
||||
self.add_blank,
|
||||
self.tokenizer,
|
||||
self.phoneme_cache_path,
|
||||
)
|
||||
else:
|
||||
text = np.asarray(
|
||||
text_to_sequence(
|
||||
item["text"],
|
||||
[self.cleaners],
|
||||
custom_symbols=self.custom_symbols,
|
||||
tp=self.characters,
|
||||
add_blank=self.add_blank,
|
||||
),
|
||||
self.tokenizer.text_to_ids(item["text"], item["language"]),
|
||||
dtype=np.int32,
|
||||
)
|
||||
|
||||
|
@ -281,24 +232,16 @@ class TTSDataset(Dataset):
|
|||
print(" | > Computing input sequences ...")
|
||||
for idx, item in enumerate(tqdm.tqdm(self.items)):
|
||||
sequence = np.asarray(
|
||||
text_to_sequence(
|
||||
item["text"],
|
||||
[self.cleaners],
|
||||
custom_symbols=self.custom_symbols,
|
||||
tp=self.characters,
|
||||
add_blank=self.add_blank,
|
||||
),
|
||||
self.tokenizer.text_to_ids(item["text"]),
|
||||
dtype=np.int32,
|
||||
)
|
||||
self.items[idx][0] = sequence
|
||||
|
||||
else:
|
||||
func_args = [
|
||||
self.phoneme_cache_path,
|
||||
self.enable_eos_bos,
|
||||
self.cleaners,
|
||||
self.phoneme_language,
|
||||
self.custom_symbols,
|
||||
self.characters,
|
||||
self.add_blank,
|
||||
]
|
||||
|
|
Loading…
Reference in New Issue