Refactor TTSDataset

This commit is contained in:
Eren Gölge 2021-11-30 15:50:18 +01:00
parent 4597d4e5b6
commit 176b712c1a
1 changed file with 451 additions and 233 deletions


@@ -2,7 +2,7 @@ import collections
 import os
 import random
 from multiprocessing import Pool
-from typing import Dict, List
+from typing import Dict, List, Union

 import numpy as np
 import torch
@@ -14,6 +14,24 @@ from TTS.tts.utils.text import TTSTokenizer
 from TTS.utils.audio import AudioProcessor


+def _parse_sample(item):
+    language_name = None
+    attn_file = None
+    if len(item) == 5:
+        text, wav_file, speaker_name, language_name, attn_file = item
+    elif len(item) == 4:
+        text, wav_file, speaker_name, language_name = item
+    elif len(item) == 3:
+        text, wav_file, speaker_name = item
+    else:
+        raise ValueError(" [!] Dataset cannot parse the sample.")
+    return text, wav_file, speaker_name, language_name, attn_file
+
+
+def noise_augment_audio(wav):
+    return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)
+
+
 class TTSDataset(Dataset):
     def __init__(
         self,
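A quick sketch of how the two new module-level helpers behave; the sample tuple and values below are made up for illustration:

    # _parse_sample pads missing fields with None so callers can always unpack five values.
    text, wav_file, speaker_name, language_name, attn_file = _parse_sample(
        ("hello world", "wavs/LJ001-0001.wav", "ljspeech")  # hypothetical 3-field sample
    )
    assert (language_name, attn_file) == (None, None)

    # noise_augment_audio adds uniform noise below the 16-bit quantization step (1/32768).
    wav = np.zeros(16000, dtype=np.float32)
    assert noise_augment_audio(wav).shape == wav.shape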
@@ -26,9 +44,12 @@ class TTSDataset(Dataset):
         f0_cache_path: str = None,
         return_wav: bool = False,
         batch_group_size: int = 0,
-        min_seq_len: int = 0,
-        max_seq_len: int = float("inf"),
+        min_text_len: int = 0,
+        max_text_len: int = float("inf"),
+        min_audio_len: int = 0,
+        max_audio_len: int = float("inf"),
         phoneme_cache_path: str = None,
+        precompute_num_workers: int = 0,
         speaker_id_mapping: Dict = None,
         d_vector_mapping: Dict = None,
         language_id_mapping: Dict = None,
@@ -37,7 +58,7 @@ class TTSDataset(Dataset):
     ):
         """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.

-        If you need something different, you can inherit and override.
+        If you need something different, you can subclass and override.

         Args:
             outputs_per_step (int): Number of time frames predicted per step.
@@ -61,17 +82,24 @@ class TTSDataset(Dataset):
                 sequences by length. It shuffles each batch with bucketing to gather similar length sequences in a
                 batch. Set 0 to disable. Defaults to 0.

-            min_seq_len (int): Minimum input sequence length to be processed
-                by `sort_inputs`. Filter out input sequences that are shorter than this. Some models have a
-                minimum input length due to its architecture. Defaults to 0.
+            min_text_len (int): Minimum length of input text to be used. All shorter samples will be ignored.
+                Defaults to 0.

-            max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this.
-                It helps for controlling the VRAM usage against long input sequences. Especially models with
-                RNN layers are sensitive to input length. Defaults to `Inf`.
+            max_text_len (int): Maximum length of input text to be used. All longer samples will be ignored.
+                Defaults to float("inf").
+
+            min_audio_len (int): Minimum length of input audio to be used. All shorter samples will be ignored.
+                Defaults to 0.
+
+            max_audio_len (int): Maximum length of input audio to be used. All longer samples will be ignored.
+                The maximum length in the dataset defines the VRAM used in the training. Hence, pay attention to
+                this value if you encounter an OOM error in training. Defaults to float("inf").

             phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a
                 separate file. Defaults to None.

+            precompute_num_workers (int): Number of workers to precompute features. Defaults to 0.
+
             speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the
                 embedding layer. Defaults to None.
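In effect, the four new arguments define two independent keep-windows; a sample survives `preprocess_samples` only if both its text and audio lengths fall inside them (audio length is measured in samples, assuming 16-bit PCM). A sketch of the rule with hypothetical thresholds:

    def keeps(text_len, audio_len,
              min_text_len=10, max_text_len=300,
              min_audio_len=22050, max_audio_len=22050 * 10):  # hypothetical thresholds
        return (min_text_len <= text_len <= max_text_len
                and min_audio_len <= audio_len <= max_audio_len)

    assert keeps(120, 22050 * 4)       # a 4 s clip at 22.05 kHz is kept
    assert not keeps(120, 22050 * 15)  # over-long audio is dropped to bound VRAM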
@@ -83,15 +111,17 @@ class TTSDataset(Dataset):
         """
         super().__init__()
         self.batch_group_size = batch_group_size
-        self.items = meta_data
+        self._samples = meta_data
         self.outputs_per_step = outputs_per_step
         self.sample_rate = ap.sample_rate
         self.compute_linear_spec = compute_linear_spec
         self.return_wav = return_wav
         self.compute_f0 = compute_f0
         self.f0_cache_path = f0_cache_path
-        self.min_seq_len = min_seq_len
-        self.max_seq_len = max_seq_len
+        self.min_audio_len = min_audio_len
+        self.max_audio_len = max_audio_len
+        self.min_text_len = min_text_len
+        self.max_text_len = max_text_len
         self.ap = ap
         self.phoneme_cache_path = phoneme_cache_path
         self.speaker_id_mapping = speaker_id_mapping
@@ -100,112 +130,113 @@ class TTSDataset(Dataset):
         self.use_noise_augment = use_noise_augment
         self.verbose = verbose
-        self.input_seq_computed = False
         self.rescue_item_idx = 1
         self.pitch_computed = False
         self.tokenizer = tokenizer
-        if self.tokenizer.use_phonemes and not os.path.isdir(phoneme_cache_path):
-            os.makedirs(phoneme_cache_path, exist_ok=True)
+        self.audio_lengths, self.text_lengths = self.compute_lengths(self.samples)
+        if self.tokenizer.use_phonemes:
+            self.phoneme_dataset = PhonemeDataset(
+                self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers
+            )
+
         if compute_f0:
-            self.pitch_extractor = PitchExtractor(self.items, verbose=verbose)
+            self.f0_dataset = F0Dataset(
+                self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
+            )
+
         if self.verbose:
             self.print_logs()

+    @property
+    def samples(self):
+        return self._samples
+
+    @samples.setter
+    def samples(self, new_samples):
+        self._samples = new_samples
+        if hasattr(self, "f0_dataset"):
+            self.f0_dataset.samples = new_samples
+        if hasattr(self, "phoneme_dataset"):
+            self.phoneme_dataset.samples = new_samples
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, idx):
+        return self.load_data(idx)
+
     def print_logs(self, level: int = 0) -> None:
         indent = "\t" * level
         print("\n")
         print(f"{indent}> DataLoader initialization")
         print(f"{indent}| > Tokenizer:")
         self.tokenizer.print_logs(level + 1)
-        print(f"{indent}| > Number of instances : {len(self.items)}")
+        print(f"{indent}| > Number of instances : {len(self.samples)}")

     def load_wav(self, filename):
-        audio = self.ap.load_wav(filename)
-        return audio
-
-    @staticmethod
-    def load_np(filename):
-        data = np.load(filename).astype("float32")
-        return data
-
-    @staticmethod
-    def _generate_and_cache_phoneme_sequence(text, tokenizer, cache_path):
-        """generate a phoneme sequence from text.
-        since the usage is for subsequent caching, we never add bos and
-        eos chars here. Instead we add those dynamically later; based on the
-        config option."""
-        phonemes = tokenizer.text_to_ids(text)
-        phonemes = np.asarray(phonemes, dtype=np.int32)
-        np.save(cache_path, phonemes)
-        return phonemes
-
-    @staticmethod
-    def _load_or_generate_phoneme_sequence(wav_file, text, language, tokenizer, phoneme_cache_path):
-        file_name = os.path.splitext(os.path.basename(wav_file))[0]
-        # different names for normal phonemes and with blank chars.
-        file_name_ext = "_phoneme.npy"
-        cache_path = os.path.join(phoneme_cache_path, file_name + file_name_ext)
-        try:
-            phonemes = np.load(cache_path)
-        except FileNotFoundError:
-            phonemes = TTSDataset._generate_and_cache_phoneme_sequence(text, tokenizer, cache_path)
-        except (ValueError, IOError):
-            print(" [!] failed loading phonemes for {}. Recomputing.".format(wav_file))
-            phonemes = TTSDataset._generate_and_cache_phoneme_sequence(text, tokenizer, cache_path)
-        phonemes = np.asarray(phonemes, dtype=np.int32)
-        return phonemes
+        waveform = self.ap.load_wav(filename)
+        assert waveform.size > 0
+        return waveform
+
+    def get_phonemes(self, idx, text):
+        out_dict = self.phoneme_dataset[idx]
+        assert text == out_dict["text"], f"{text} != {out_dict['text']}"
+        assert out_dict["token_ids"].size > 0
+        return out_dict
+
+    def get_f0(self, idx):
+        out_dict = self.f0_dataset[idx]
+        _, wav_file, *_ = _parse_sample(self.samples[idx])
+        assert wav_file == out_dict["audio_file"]
+        return out_dict
+
+    def get_attn_mask(self, attn_file):
+        return np.load(attn_file)
+
+    def get_token_ids(self, idx, text):
+        if self.tokenizer.use_phonemes:
+            token_ids = self.get_phonemes(idx, text)["token_ids"]
+        else:
+            token_ids = self.tokenizer.text_to_ids(text)
+        return token_ids

     def load_data(self, idx):
-        item = self.items[idx]
+        item = self.samples[idx]

         raw_text = item["text"]

         wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32)

         # apply noise for augmentation
         if self.use_noise_augment:
-            wav = wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)
-
-        if not self.input_seq_computed:
-            if self.tokenizer.use_phonemes:
-                text = self._load_or_generate_phoneme_sequence(
-                    item["audio_file"],
-                    item["text"],
-                    item["language"] if item["language"] else self.phoneme_language,
-                    self.tokenizer,
-                    self.phoneme_cache_path,
-                )
-            else:
-                text = np.asarray(
-                    self.tokenizer.text_to_ids(item["text"], item["language"]),
-                    dtype=np.int32,
-                )
-
-        assert text.size > 0, self.items[idx]["audio_file"]
-        assert wav.size > 0, self.items[idx]["audio_file"]
-
+            wav = noise_augment_audio(wav)
+
+        # get token ids
+        token_ids = self.get_token_ids(idx, item["text"])
+
+        # get pre-computed attention maps
         attn = None
         if "alignment_file" in item:
-            attn = np.load(item["alignment_file"])
-
-        if len(text) > self.max_seq_len:
-            # return a different sample if the phonemized
-            # text is longer than the threshold
-            # TODO: find a better fix
+            attn = self.get_attn_mask(item["alignment_file"])
+
+        # after phonemization the text length may change
+        # this is a shameful 🤭 hack to prevent longer phonemes
+        # TODO: find a better fix
+        if len(token_ids) > self.max_text_len:
             return self.load_data(self.rescue_item_idx)

-        pitch = None
+        # get f0 values
+        f0 = None
         if self.compute_f0:
-            pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, item["audio_file"], self.f0_cache_path)
-            pitch = self.pitch_extractor.normalize_pitch(pitch.astype(np.float32))
+            f0 = self.get_f0(idx)["f0"]

         sample = {
             "raw_text": raw_text,
-            "text": text,
+            "token_ids": token_ids,
             "wav": wav,
-            "pitch": pitch,
+            "pitch": f0,
             "attn": attn,
             "item_idx": item["audio_file"],
             "speaker_name": item["speaker_name"],
@@ -215,105 +246,78 @@ class TTSDataset(Dataset):
         return sample

     @staticmethod
-    def _phoneme_worker(args):
-        item = args[0]
-        func_args = args[1]
-        func_args[3] = (
-            item["language"] if "language" in item and item["language"] else func_args[3]
-        )  # override phoneme language if specified by the dataset formatter
-        phonemes = TTSDataset._load_or_generate_phoneme_sequence(item["audio_file"], item["text"], *func_args)
-        return phonemes
-
-    def compute_input_seq(self, num_workers=0):
-        """Compute the input sequences with multi-processing.
-        Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
-        if not self.use_phonemes:
-            if self.verbose:
-                print(" | > Computing input sequences ...")
-            for idx, item in enumerate(tqdm.tqdm(self.items)):
-                sequence = np.asarray(
-                    self.tokenizer.text_to_ids(item["text"]),
-                    dtype=np.int32,
-                )
-                self.items[idx][0] = sequence
-        else:
-            func_args = [
-                self.phoneme_cache_path,
-                self.enable_eos_bos,
-                self.cleaners,
-                self.phoneme_language,
-                self.characters,
-                self.add_blank,
-            ]
-            if self.verbose:
-                print(" | > Computing phonemes ...")
-            if num_workers == 0:
-                for idx, item in enumerate(tqdm.tqdm(self.items)):
-                    phonemes = self._phoneme_worker([item, func_args])
-                    self.items[idx][0] = phonemes
-            else:
-                with Pool(num_workers) as p:
-                    phonemes = list(
-                        tqdm.tqdm(
-                            p.imap(TTSDataset._phoneme_worker, [[item, func_args] for item in self.items]),
-                            total=len(self.items),
-                        )
-                    )
-                    for idx, p in enumerate(phonemes):
-                        self.items[idx][0] = p
-
-    def sort_and_filter_items(self, by_audio_len=False):
-        r"""Sort `items` based on text length or audio length in ascending order. Filter out samples outside the
-        length range.
-
-        Args:
-            by_audio_len (bool): if True, sort by audio length else by text length.
-        """
-        # compute the target sequence length
-        if by_audio_len:
-            lengths = []
-            for item in self.items:
-                lengths.append(os.path.getsize(item["audio_file"]) / 16 * 8)  # assuming 16bit audio
-            lengths = np.array(lengths)
-        else:
-            lengths = np.array([len(ins["text"]) for ins in self.items])
-
-        idxs = np.argsort(lengths)
-        new_items = []
-        ignored = []
+    def compute_lengths(samples):
+        audio_lengths = []
+        text_lengths = []
+        for item in samples:
+            text, wav_file, *_ = _parse_sample(item)
+            audio_lengths.append(os.path.getsize(wav_file) / 16 * 8)  # assuming 16bit audio
+            text_lengths.append(len(text))
+        audio_lengths = np.array(audio_lengths)
+        text_lengths = np.array(text_lengths)
+        return audio_lengths, text_lengths
+
+    @staticmethod
+    def sort_and_filter_by_length(lengths: List[int], min_len: int, max_len: int):
+        idxs = np.argsort(lengths)  # ascending order
+        ignore_idx = []
+        keep_idx = []
         for i, idx in enumerate(idxs):
             length = lengths[idx]
-            if length < self.min_seq_len or length > self.max_seq_len:
-                ignored.append(idx)
+            if length < min_len or length > max_len:
+                ignore_idx.append(idx)
             else:
-                new_items.append(self.items[idx])
+                keep_idx.append(idx)
+        return ignore_idx, keep_idx
+
+    @staticmethod
+    def create_buckets(samples, batch_group_size: int):
+        for i in range(len(samples) // batch_group_size):
+            offset = i * batch_group_size
+            end_offset = offset + batch_group_size
+            temp_items = samples[offset:end_offset]
+            random.shuffle(temp_items)
+            samples[offset:end_offset] = temp_items
+        return samples
+
+    def preprocess_samples(self):
+        r"""Sort `samples` by text length or audio length in ascending order. Filter out samples outside the
+        length range.
+        """
+        # sort items based on the sequence length in ascending order
+        text_ignore_idx, text_keep_idx = self.sort_and_filter_by_length(
+            self.text_lengths, self.min_text_len, self.max_text_len
+        )
+        audio_ignore_idx, audio_keep_idx = self.sort_and_filter_by_length(
+            self.audio_lengths, self.min_audio_len, self.max_audio_len
+        )
+        keep_idx = list(set(audio_keep_idx) & set(text_keep_idx))
+        ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx))
+
+        samples = []
+        for idx in keep_idx:
+            samples.append(self.samples[idx])
+
+        if len(samples) == 0:
+            raise RuntimeError(" [!] No samples left")

         # shuffle batch groups
-        if self.batch_group_size > 0:
-            for i in range(len(new_items) // self.batch_group_size):
-                offset = i * self.batch_group_size
-                end_offset = offset + self.batch_group_size
-                temp_items = new_items[offset:end_offset]
-                random.shuffle(temp_items)
-                new_items[offset:end_offset] = temp_items
-        self.items = new_items
+        # create batches with similar length items
+        # the larger the `batch_group_size`, the higher the length variety in a batch.
+        if self.batch_group_size > 0:
+            samples = self.create_buckets(samples, self.batch_group_size)
+
+        # update items to the new sorted items
+        self.samples = samples

         if self.verbose:
-            print(" | > Max length sequence: {}".format(np.max(lengths)))
-            print(" | > Min length sequence: {}".format(np.min(lengths)))
-            print(" | > Avg length sequence: {}".format(np.mean(lengths)))
-            print(
-                " | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}".format(
-                    self.max_seq_len, self.min_seq_len, len(ignored)
-                )
-            )
+            print(" | > Preprocessing samples")
+            print(" | > Max text length: {}".format(np.max(self.text_lengths)))
+            print(" | > Min text length: {}".format(np.min(self.text_lengths)))
+            print(" | > Avg text length: {}".format(np.mean(self.text_lengths)))
+            print(" | ")
+            print(" | > Max audio length: {}".format(np.max(self.audio_lengths)))
+            print(" | > Min audio length: {}".format(np.min(self.audio_lengths)))
+            print(" | > Avg audio length: {}".format(np.mean(self.audio_lengths)))
+            print(f" | > Num. instances discarded: {len(ignore_idx)}")
             print(" | > Batch group size: {}.".format(self.batch_group_size))

-    def __len__(self):
-        return len(self.items)
-
-    def __getitem__(self, idx):
-        return self.load_data(idx)
-
     @staticmethod
     def _sort_batch(batch, text_lengths):
         """Sort the batch by the input text length for RNN efficiency.
@@ -338,10 +342,10 @@ class TTSDataset(Dataset):
         # Puts each data field into a tensor with outer dimension batch size
         if isinstance(batch[0], collections.abc.Mapping):
-            text_lengths = np.array([len(d["text"]) for d in batch])
+            token_ids_lengths = np.array([len(d["token_ids"]) for d in batch])

             # sort items with text input length for RNN efficiency
-            batch, text_lengths, ids_sorted_decreasing = self._sort_batch(batch, text_lengths)
+            batch, token_ids_lengths, ids_sorted_decreasing = self._sort_batch(batch, token_ids_lengths)

             # convert list of dicts to dict of lists
             batch = {k: [dic[k] for dic in batch] for k in batch[0]}
@@ -383,7 +387,7 @@ class TTSDataset(Dataset):
             stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)

             # PAD sequences with longest instance in the batch
-            text = prepare_data(batch["text"]).astype(np.int32)
+            text = prepare_data(batch["token_ids"]).astype(np.int32)

             # PAD features with longest instance
             mel = prepare_tensor(mel, self.outputs_per_step)
@@ -392,12 +396,13 @@ class TTSDataset(Dataset):
             mel = mel.transpose(0, 2, 1)

             # convert things to pytorch
-            text_lengths = torch.LongTensor(text_lengths)
+            token_ids_lengths = torch.LongTensor(token_ids_lengths)
             text = torch.LongTensor(text)
             mel = torch.FloatTensor(mel).contiguous()
             mel_lengths = torch.LongTensor(mel_lengths)
             stop_targets = torch.FloatTensor(stop_targets)

+            # speaker vectors
             if d_vectors is not None:
                 d_vectors = torch.FloatTensor(d_vectors)
@@ -408,14 +413,13 @@ class TTSDataset(Dataset):
                 language_ids = torch.LongTensor(language_ids)

             # compute linear spectrogram
+            linear = None
             if self.compute_linear_spec:
                 linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
                 linear = prepare_tensor(linear, self.outputs_per_step)
                 linear = linear.transpose(0, 2, 1)
                 assert mel.shape[1] == linear.shape[1]
                 linear = torch.FloatTensor(linear).contiguous()
-            else:
-                linear = None

             # format waveforms
             wav_padded = None
@@ -431,8 +435,7 @@ class TTSDataset(Dataset):
                     wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
                 wav_padded.transpose_(1, 2)

-            # compute f0
-            # TODO: compare perf in collate_fn vs in load_data
+            # format F0
             if self.compute_f0:
                 pitch = prepare_data(batch["pitch"])
                 assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
@@ -440,7 +443,8 @@ class TTSDataset(Dataset):
             else:
                 pitch = None

-            # collate attention alignments
+            # format attention masks
+            attns = None
             if batch["attn"][0] is not None:
                 attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing]
                 for idx, attn in enumerate(attns):
@@ -451,12 +455,10 @@ class TTSDataset(Dataset):
                     attns[idx] = attn
                 attns = prepare_tensor(attns, self.outputs_per_step)
                 attns = torch.FloatTensor(attns).unsqueeze(1)
-            else:
-                attns = None
-            # TODO: return dictionary
+
             return {
-                "text": text,
-                "text_lengths": text_lengths,
+                "token_id": text,
+                "token_id_lengths": token_ids_lengths,
                 "speaker_names": batch["speaker_name"],
                 "linear": linear,
                 "mel": mel,
@@ -482,22 +484,179 @@
     )


-class PitchExtractor:
-    """Pitch Extractor for computing F0 from wav files.
-
-    Args:
-        items (List[List]): Dataset samples.
-        verbose (bool): Whether to print the progress.
-    """
-
-    def __init__(
-        self,
-        items: List[Dict],
-        verbose=False,
-    ):
-        self.items = items
-        self.verbose = verbose
-        self.mean = None
-        self.std = None
+class PhonemeDataset(Dataset):
+    """Phoneme Dataset for converting input text to phonemes and then to token IDs.
+
+    At initialization, it pre-computes the phonemes under `cache_path` and loads them in training to reduce data
+    loading latency. If `cache_path` is already present, it skips the pre-computation.
+
+    Args:
+        samples (Union[List[List], List[Dict]]):
+            List of samples. Each sample is a list or a dict.
+
+        tokenizer (TTSTokenizer):
+            Tokenizer to convert input text to phonemes.
+
+        cache_path (str):
+            Path to cache phonemes. If `cache_path` is already present or None, it skips the pre-computation.
+
+        precompute_num_workers (int):
+            Number of workers used for pre-computing the phonemes. Defaults to 0.
+    """
+
+    def __init__(
+        self,
+        samples: Union[List[Dict], List[List]],
+        tokenizer: "TTSTokenizer",
+        cache_path: str,
+        precompute_num_workers=0,
+    ):
+        self.samples = samples
+        self.tokenizer = tokenizer
+        self.cache_path = cache_path
+        if cache_path is not None and not os.path.exists(cache_path):
+            os.makedirs(cache_path)
+            self.precompute(precompute_num_workers)
+
+    def __getitem__(self, index):
+        text, wav_file, *_ = _parse_sample(self.samples[index])
+        ids = self.compute_or_load(wav_file, text)
+        ph_hat = self.tokenizer.ids_to_text(ids)
+        return {"text": text, "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}
+
+    def __len__(self):
+        return len(self.samples)
+
+    def compute_or_load(self, wav_file, text):
+        """Compute phonemes for the given text and cache them. If the phonemes are already cached, load them
+        from the cache instead."""
+        file_name = os.path.splitext(os.path.basename(wav_file))[0]
+        file_ext = "_phoneme.npy"
+        cache_path = os.path.join(self.cache_path, file_name + file_ext)
+        try:
+            ids = np.load(cache_path)
+        except FileNotFoundError:
+            ids = self.tokenizer.text_to_ids(text)
+            np.save(cache_path, ids)
+        return ids
+
+    def get_pad_id(self):
+        """Get the pad token ID used for sequence padding."""
+        return self.tokenizer.pad_id
+
+    def precompute(self, num_workers=1):
+        """Precompute phonemes for all samples.
+
+        We use a PyTorch DataLoader here because it gives us batching and multi-processing for free.
+        """
+        with tqdm.tqdm(total=len(self)) as pbar:
+            batch_size = num_workers if num_workers > 0 else 1
+            dataloader = torch.utils.data.DataLoader(
+                batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
+            )
+            for _ in dataloader:
+                pbar.update(batch_size)
+
+    def collate_fn(self, batch):
+        ids = [item["token_ids"] for item in batch]
+        ids_lens = [item["token_ids_len"] for item in batch]
+        texts = [item["text"] for item in batch]
+        texts_hat = [item["ph_hat"] for item in batch]
+        ids_lens_max = max(ids_lens)
+        ids_torch = torch.LongTensor(len(ids), ids_lens_max).fill_(self.get_pad_id())
+        for i, ids_len in enumerate(ids_lens):
+            ids_torch[i, :ids_len] = torch.LongTensor(ids[i])
+        return {"text": texts, "ph_hat": texts_hat, "token_ids": ids_torch}
+
+    def print_logs(self, level: int = 0) -> None:
+        indent = "\t" * level
+        print("\n")
+        print(f"{indent}> PhonemeDataset ")
+        print(f"{indent}| > Tokenizer:")
+        self.tokenizer.print_logs(level + 1)
+        print(f"{indent}| > Number of instances : {len(self.samples)}")
+
+
+class F0Dataset:
+    """F0 Dataset for computing F0 from wav files on CPU.
+
+    Pre-computes F0 values for all the samples at initialization if `cache_path` is given and not already
+    present. It also computes the mean and std of the F0 values if `normalize_f0` is True.
+
+    Args:
+        samples (Union[List[List], List[Dict]]):
+            List of samples. Each sample is a list or a dict.
+
+        ap (AudioProcessor):
+            AudioProcessor to compute F0 from wav files.
+
+        cache_path (str):
+            Path to cache F0 values. If `cache_path` is already present or None, it skips the pre-computation.
+            Defaults to None.
+
+        precompute_num_workers (int):
+            Number of workers used for pre-computing the F0 values. Defaults to 0.
+
+        normalize_f0 (bool):
+            Whether to normalize F0 values by mean and std. Defaults to True.
+    """
+
+    def __init__(
+        self,
+        samples: Union[List[List], List[Dict]],
+        ap: "AudioProcessor",
+        verbose=False,
+        cache_path: str = None,
+        precompute_num_workers=0,
+        normalize_f0=True,
+    ):
+        self.samples = samples
+        self.ap = ap
+        self.verbose = verbose
+        self.cache_path = cache_path
+        self.normalize_f0 = normalize_f0
+        self.pad_id = 0.0
+        self.mean = None
+        self.std = None
+        if cache_path is not None and not os.path.exists(cache_path):
+            os.makedirs(cache_path)
+            self.precompute(precompute_num_workers)
+        if normalize_f0:
+            self.load_stats(cache_path)
+
+    def __getitem__(self, idx):
+        _, wav_file, *_ = _parse_sample(self.samples[idx])
+        f0 = self.compute_or_load(wav_file)
+        if self.normalize_f0:
+            assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available"
+            f0 = self.normalize(f0)
+        return {"audio_file": wav_file, "f0": f0}
+
+    def __len__(self):
+        return len(self.samples)
+
+    def precompute(self, num_workers=0):
+        with tqdm.tqdm(total=len(self)) as pbar:
+            batch_size = num_workers if num_workers > 0 else 1
+            # disable normalization for the first pass: stats do not exist yet and raw values must be cached
+            normalize_f0 = self.normalize_f0
+            self.normalize_f0 = False
+            dataloader = torch.utils.data.DataLoader(
+                batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
+            )
+            computed_data = []
+            for batch in dataloader:
+                f0 = batch["f0"]
+                computed_data.append([f for f in f0])
+                pbar.update(batch_size)
+            self.normalize_f0 = normalize_f0
+
+        if self.normalize_f0:
+            computed_data = [tensor for batch in computed_data for tensor in batch]  # flatten
+            pitch_mean, pitch_std = self.compute_pitch_stats(computed_data)
+            pitch_stats = {"mean": pitch_mean, "std": pitch_std}
+            np.save(os.path.join(self.cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)
+
+    def get_pad_id(self):
+        return self.pad_id

     @staticmethod
     def create_pitch_file_path(wav_file, cache_path):
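Standalone use of PhonemeDataset looks roughly like this (the cache path is made up; on the first pass token IDs are written to <cache_path>/<wav_basename>_phoneme.npy and only loaded afterwards):

    ph_dataset = PhonemeDataset(samples, tokenizer, cache_path="/tmp/phoneme_cache", precompute_num_workers=2)
    item = ph_dataset[0]
    item["text"]       # input text
    item["ph_hat"]     # text reconstructed from the IDs, handy for sanity checks
    item["token_ids"]  # np.ndarray loaded from the cache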
@@ -519,69 +678,128 @@ class PitchExtractor:
         mean, std = np.mean(nonzeros), np.std(nonzeros)
         return mean, std

-    def normalize_pitch(self, pitch):
+    def load_stats(self, cache_path):
+        stats_path = os.path.join(cache_path, "pitch_stats.npy")
+        stats = np.load(stats_path, allow_pickle=True).item()
+        self.mean = stats["mean"].astype(np.float32)
+        self.std = stats["std"].astype(np.float32)
+
+    def normalize(self, pitch):
         zero_idxs = np.where(pitch == 0.0)[0]
         pitch = pitch - self.mean
         pitch = pitch / self.std
         pitch[zero_idxs] = 0.0
         return pitch

-    def denormalize_pitch(self, pitch):
+    def denormalize(self, pitch):
         zero_idxs = np.where(pitch == 0.0)[0]
         pitch *= self.std
         pitch += self.mean
         pitch[zero_idxs] = 0.0
         return pitch

-    @staticmethod
-    def load_or_compute_pitch(ap, wav_file, cache_path):
+    def compute_or_load(self, wav_file):
         """
         Compute pitch and return a numpy array of pitch values.
         """
-        pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path)
+        pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
         if not os.path.exists(pitch_file):
-            pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
+            pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
         else:
             pitch = np.load(pitch_file)
         return pitch.astype(np.float32)

-    @staticmethod
-    def _pitch_worker(args):
-        item = args[0]
-        ap = args[1]
-        cache_path = args[2]
-        pitch_file = PitchExtractor.create_pitch_file_path(item["audio_file"], cache_path)
-        if not os.path.exists(pitch_file):
-            pitch = PitchExtractor._compute_and_save_pitch(ap, item["audio_file"], pitch_file)
-            return pitch
-        return None
-
-    def compute_pitch(self, ap, cache_path, num_workers=0):
-        """Compute the input sequences with multi-processing.
-        Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
-        if not os.path.exists(cache_path):
-            os.makedirs(cache_path, exist_ok=True)
-        if self.verbose:
-            print(" | > Computing pitch features ...")
-        if num_workers == 0:
-            pitch_vecs = []
-            for _, item in enumerate(tqdm.tqdm(self.items)):
-                pitch_vecs += [self._pitch_worker([item, ap, cache_path])]
-        else:
-            with Pool(num_workers) as p:
-                pitch_vecs = list(
-                    tqdm.tqdm(
-                        p.imap(PitchExtractor._pitch_worker, [[item, ap, cache_path] for item in self.items]),
-                        total=len(self.items),
-                    )
-                )
-        pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs)
-        pitch_stats = {"mean": pitch_mean, "std": pitch_std}
-        np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)
-
-    def load_pitch_stats(self, cache_path):
-        stats_path = os.path.join(cache_path, "pitch_stats.npy")
-        stats = np.load(stats_path, allow_pickle=True).item()
-        self.mean = stats["mean"].astype(np.float32)
-        self.std = stats["std"].astype(np.float32)
+    def collate_fn(self, batch):
+        audio_file = [item["audio_file"] for item in batch]
+        f0s = [item["f0"] for item in batch]
+        f0_lens = [len(item["f0"]) for item in batch]
+        f0_lens_max = max(f0_lens)
+        f0s_torch = torch.FloatTensor(len(f0s), f0_lens_max).fill_(self.get_pad_id())
+        for i, f0_len in enumerate(f0_lens):
+            f0s_torch[i, :f0_len] = torch.FloatTensor(f0s[i])
+        return {"audio_file": audio_file, "f0": f0s_torch, "f0_lens": f0_lens}
+
+    def print_logs(self, level: int = 0) -> None:
+        indent = "\t" * level
+        print("\n")
+        print(f"{indent}> F0Dataset ")
+        print(f"{indent}| > Number of instances : {len(self.samples)}")
+
+
+if __name__ == "__main__":
+    from torch.utils.data import DataLoader
+
+    from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig
+    from TTS.tts.datasets import load_tts_samples
+    from TTS.tts.utils.text.characters import IPAPhonemes
+    from TTS.tts.utils.text.phonemizers import ESpeak
+
+    dataset_config = BaseDatasetConfig(
+        name="ljspeech",
+        meta_file_train="metadata.csv",
+        path="/Users/erengolge/Projects/TTS/recipes/ljspeech/LJSpeech-1.1",
+    )
+    train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
+    samples = train_samples + eval_samples
+
+    phonemizer = ESpeak(language="en-us")
+    tokenizer = TTSTokenizer(use_phonemes=True, characters=IPAPhonemes(), phonemizer=phonemizer)
+    # ph_dataset = PhonemeDataset(samples, tokenizer, cache_path="/Users/erengolge/Projects/TTS/phonemes_tests")
+    # ph_dataset.precompute(num_workers=4)
+    # dataloader = DataLoader(ph_dataset, batch_size=4, shuffle=False, num_workers=4, collate_fn=ph_dataset.collate_fn)
+    # for batch in dataloader:
+    #     print(batch)
+    #     break
+
+    audio_config = BaseAudioConfig(
+        sample_rate=22050,
+        win_length=1024,
+        hop_length=256,
+        num_mels=80,
+        preemphasis=0.0,
+        ref_level_db=20,
+        log_func="np.log",
+        do_trim_silence=True,
+        trim_db=45,
+        mel_fmin=0,
+        mel_fmax=8000,
+        spec_gain=1.0,
+        signal_norm=False,
+        do_amp_to_db_linear=False,
+    )
+
+    ap = AudioProcessor.init_from_config(audio_config)
+
+    # f0_dataset = F0Dataset(samples, ap, cache_path="/Users/erengolge/Projects/TTS/f0_tests", verbose=False, precompute_num_workers=4)
+    # dataloader = DataLoader(f0_dataset, batch_size=4, shuffle=False, num_workers=4, collate_fn=f0_dataset.collate_fn)
+    # for batch in dataloader:
+    #     print(batch)
+    #     breakpoint()
+    #     break
+
+    dataset = TTSDataset(
+        outputs_per_step=1,
+        compute_linear_spec=False,
+        meta_data=samples,
+        ap=ap,
+        return_wav=False,
+        batch_group_size=0,
+        min_text_len=0,
+        max_text_len=500,
+        use_noise_augment=False,
+        verbose=True,
+        speaker_id_mapping=None,
+        d_vector_mapping=None,
+        compute_f0=True,
+        f0_cache_path="/Users/erengolge/Projects/TTS/f0_tests",
+        tokenizer=tokenizer,
+        phoneme_cache_path="/Users/erengolge/Projects/TTS/phonemes_tests",
+        precompute_num_workers=4,
+    )
+
+    dataloader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0, collate_fn=dataset.collate_fn)
+    for batch in dataloader:
+        print(batch)
+        break
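For completeness, a sketch of the F0 normalization round trip, reusing `samples` and `ap` from the demo above. The cache path is hypothetical and must contain pitch_stats.npy, which precompute writes when the cache directory is created fresh:

    f0_dataset = F0Dataset(samples, ap, cache_path="/tmp/f0_tests", precompute_num_workers=4)
    norm = f0_dataset[0]["f0"].copy()         # normalized; zeros mark unvoiced frames
    hz = f0_dataset.denormalize(norm.copy())  # denormalize mutates its input, hence the copy
    assert np.allclose(f0_dataset.normalize(hz.copy()), norm, atol=1e-5)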