Refactor TTSDataset

Eren Gölge 2021-11-30 15:50:18 +01:00
parent 4597d4e5b6
commit 176b712c1a
1 changed file with 451 additions and 233 deletions


@@ -2,7 +2,7 @@ import collections
import os
import random
from multiprocessing import Pool
from typing import Dict, List
from typing import Dict, List, Union
import numpy as np
import torch
@@ -14,6 +14,24 @@ from TTS.tts.utils.text import TTSTokenizer
from TTS.utils.audio import AudioProcessor
def _parse_sample(item):
language_name = None
attn_file = None
if len(item) == 5:
text, wav_file, speaker_name, language_name, attn_file = item
elif len(item) == 4:
text, wav_file, speaker_name, language_name = item
elif len(item) == 3:
text, wav_file, speaker_name = item
else:
raise ValueError(" [!] Dataset cannot parse the sample.")
return text, wav_file, speaker_name, language_name, attn_file
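A minimal sketch of the three sample layouts `_parse_sample` accepts; the tuples below are hypothetical, not taken from any dataset formatter:

# Hypothetical samples illustrating the three accepted layouts.
assert _parse_sample(["hi", "a.wav", "spk1"]) == ("hi", "a.wav", "spk1", None, None)
assert _parse_sample(["hi", "a.wav", "spk1", "en"]) == ("hi", "a.wav", "spk1", "en", None)
assert _parse_sample(["hi", "a.wav", "spk1", "en", "a_attn.npy"]) == ("hi", "a.wav", "spk1", "en", "a_attn.npy")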
def noise_augment_audio(wav):
return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)
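The augmentation adds uniform noise at the scale of one 16-bit quantization step (1/32768). A quick sketch of its effect, assuming a float32 waveform:

wav = np.zeros(4, dtype=np.float32)  # silent input for illustration
noisy = noise_augment_audio(wav)
assert noisy.shape == wav.shape
assert np.all(noisy >= 0.0) and np.all(noisy < 1.0 / 32768.0)  # np.random.rand is uniform in [0, 1)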
class TTSDataset(Dataset):
def __init__(
self,
@@ -26,9 +44,12 @@ class TTSDataset(Dataset):
f0_cache_path: str = None,
return_wav: bool = False,
batch_group_size: int = 0,
min_seq_len: int = 0,
max_seq_len: int = float("inf"),
min_text_len: int = 0,
max_text_len: int = float("inf"),
min_audio_len: int = 0,
max_audio_len: int = float("inf"),
phoneme_cache_path: str = None,
precompute_num_workers: int = 0,
speaker_id_mapping: Dict = None,
d_vector_mapping: Dict = None,
language_id_mapping: Dict = None,
@@ -37,7 +58,7 @@ class TTSDataset(Dataset):
):
"""Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
If you need something different, you can inherit and override.
If you need something different, you can subclass and override.
Args:
outputs_per_step (int): Number of time frames predicted per step.
@@ -61,17 +82,24 @@ class TTSDataset(Dataset):
sequences by length. It shuffles each batch with bucketing to gather similar length sequences in a
batch. Set 0 to disable. Defaults to 0.
min_seq_len (int): Minimum input sequence length to be processed
by `sort_inputs`. Filter out input sequences that are shorter than this. Some models have a
minimum input length due to their architecture. Defaults to 0.
min_text_len (int): Minimum length of input text to be used. All shorter samples will be ignored.
Defaults to 0.
max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this.
It helps for controlling the VRAM usage against long input sequences. Especially models with
RNN layers are sensitive to input length. Defaults to `Inf`.
max_text_len (int): Maximum length of input text to be used. All longer samples will be ignored.
Defaults to float("inf").
min_audio_len (int): Minimum length of input audio to be used. All shorter samples will be ignored.
Defaults to 0.
max_audio_len (int): Maximum length of input audio to be used. All longer samples will be ignored.
The maximum length in the dataset defines the VRAM used in the training. Hence, pay attention to
this value if you encounter an OOM error in training. Defaults to float("inf").
phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a
separate file. Defaults to None.
precompute_num_workers (int): Number of workers to precompute features. Defaults to 0.
speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the
embedding layer. Defaults to None.
@@ -83,15 +111,17 @@ class TTSDataset(Dataset):
"""
super().__init__()
self.batch_group_size = batch_group_size
self.items = meta_data
self._samples = meta_data
self.outputs_per_step = outputs_per_step
self.sample_rate = ap.sample_rate
self.compute_linear_spec = compute_linear_spec
self.return_wav = return_wav
self.compute_f0 = compute_f0
self.f0_cache_path = f0_cache_path
self.min_seq_len = min_seq_len
self.max_seq_len = max_seq_len
self.min_audio_len = min_audio_len
self.max_audio_len = max_audio_len
self.min_text_len = min_text_len
self.max_text_len = max_text_len
self.ap = ap
self.phoneme_cache_path = phoneme_cache_path
self.speaker_id_mapping = speaker_id_mapping
@@ -100,112 +130,113 @@ class TTSDataset(Dataset):
self.use_noise_augment = use_noise_augment
self.verbose = verbose
self.input_seq_computed = False
self.rescue_item_idx = 1
self.pitch_computed = False
self.tokenizer = tokenizer
if self.tokenizer.use_phonemes and not os.path.isdir(phoneme_cache_path):
os.makedirs(phoneme_cache_path, exist_ok=True)
self.audio_lengths, self.text_lengths = self.compute_lengths(self.samples)
if self.tokenizer.use_phonemes:
self.phoneme_dataset = PhonemeDataset(
self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers
)
if compute_f0:
self.pitch_extractor = PitchExtractor(self.items, verbose=verbose)
self.f0_dataset = F0Dataset(
self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
)
if self.verbose:
self.print_logs()
@property
def samples(self):
return self._samples
@samples.setter
def samples(self, new_samples):
self._samples = new_samples
if hasattr(self, "f0_dataset"):
self.f0_dataset.samples = new_samples
if hasattr(self, "phoneme_dataset"):
self.phoneme_dataset.samples = new_samples
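The property/setter pair keeps the helper datasets in sync whenever the sample list is replaced, for instance after filtering. A standalone sketch of the same pattern, with hypothetical `_Owner` and `_View` classes:

class _View:
    def __init__(self, samples):
        self.samples = samples

class _Owner:
    def __init__(self, samples):
        self._samples = samples
        self.view = _View(samples)  # stands in for f0_dataset / phoneme_dataset

    @property
    def samples(self):
        return self._samples

    @samples.setter
    def samples(self, new_samples):
        self._samples = new_samples
        if hasattr(self, "view"):
            self.view.samples = new_samples

owner = _Owner([1, 2, 3])
owner.samples = [1]  # the dependent view follows the update
assert owner.view.samples == [1]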
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
return self.load_data(idx)
def print_logs(self, level: int = 0) -> None:
indent = "\t" * level
print("\n")
print(f"{indent}> DataLoader initialization")
print(f"{indent}| > Tokenizer:")
self.tokenizer.print_logs(level + 1)
print(f"{indent}| > Number of instances : {len(self.items)}")
print(f"{indent}| > Number of instances : {len(self.samples)}")
def load_wav(self, filename):
audio = self.ap.load_wav(filename)
return audio
waveform = self.ap.load_wav(filename)
assert waveform.size > 0
return waveform
@staticmethod
def load_np(filename):
data = np.load(filename).astype("float32")
return data
def get_phonemes(self, idx, text):
out_dict = self.phoneme_dataset[idx]
assert text == out_dict["text"], f"{text} != {out_dict['text']}"
assert out_dict["token_ids"].size > 0
return out_dict
@staticmethod
def _generate_and_cache_phoneme_sequence(text, tokenizer, cache_path):
"""generate a phoneme sequence from text.
since the usage is for subsequent caching, we never add bos and
eos chars here. Instead we add those dynamically later; based on the
config option."""
phonemes = tokenizer.text_to_ids(text)
phonemes = np.asarray(phonemes, dtype=np.int32)
np.save(cache_path, phonemes)
return phonemes
def get_f0(self, idx):
out_dict = self.f0_dataset[idx]
_, wav_file, *_ = _parse_sample(self.samples[idx])
assert wav_file == out_dict["audio_file"]
return out_dict
@staticmethod
def _load_or_generate_phoneme_sequence(wav_file, text, language, tokenizer, phoneme_cache_path):
file_name = os.path.splitext(os.path.basename(wav_file))[0]
def get_attn_mask(self, attn_file):
return np.load(attn_file)
# different names for normal phonemes and with blank chars.
file_name_ext = "_phoneme.npy"
cache_path = os.path.join(phoneme_cache_path, file_name + file_name_ext)
try:
phonemes = np.load(cache_path)
except FileNotFoundError:
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(text, tokenizer, cache_path)
except (ValueError, IOError):
print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(text, tokenizer, cache_path)
phonemes = np.asarray(phonemes, dtype=np.int32)
return phonemes
def get_token_ids(self, idx, text):
if self.tokenizer.use_phonemes:
token_ids = self.get_phonemes(idx, text)["token_ids"]
else:
token_ids = self.tokenizer.text_to_ids(text)
return token_ids
def load_data(self, idx):
item = self.items[idx]
item = self.samples[idx]
raw_text = item["text"]
wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32)
wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32)
# apply noise for augmentation
if self.use_noise_augment:
wav = wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)
wav = noise_augment_audio(wav)
if not self.input_seq_computed:
if self.tokenizer.use_phonemes:
text = self._load_or_generate_phoneme_sequence(
item["audio_file"],
item["text"],
item["language"] if item["language"] else self.phoneme_language,
self.tokenizer,
self.phoneme_cache_path,
)
else:
text = np.asarray(
self.tokenizer.text_to_ids(item["text"], item["language"]),
dtype=np.int32,
)
assert text.size > 0, self.items[idx]["audio_file"]
assert wav.size > 0, self.items[idx]["audio_file"]
# get token ids
token_ids = self.get_token_ids(idx, item["text"])
# get pre-computed attention maps
attn = None
if "alignment_file" in item:
attn = np.load(item["alignment_file"])
attn = self.get_attn_mask(item["alignment_file"])
if len(text) > self.max_seq_len:
# return a different sample if the phonemized
# text is longer than the threshold
# TODO: find a better fix
# after phonemization the text length may change
# this is a shameful 🤭 hack to prevent longer phonemes
# TODO: find a better fix
if len(token_ids) > self.max_text_len:
return self.load_data(self.rescue_item_idx)
pitch = None
# get f0 values
f0 = None
if self.compute_f0:
pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, item["audio_file"], self.f0_cache_path)
pitch = self.pitch_extractor.normalize_pitch(pitch.astype(np.float32))
f0 = self.get_f0(idx)["f0"]
sample = {
"raw_text": raw_text,
"text": text,
"token_ids": token_ids,
"wav": wav,
"pitch": pitch,
"pitch": f0,
"attn": attn,
"item_idx": item["audio_file"],
"speaker_name": item["speaker_name"],
@@ -215,105 +246,78 @@ class TTSDataset(Dataset):
return sample
@staticmethod
def _phoneme_worker(args):
item = args[0]
func_args = args[1]
func_args[3] = (
item["language"] if "language" in item and item["language"] else func_args[3]
) # override phoneme language if specified by the dataset formatter
phonemes = TTSDataset._load_or_generate_phoneme_sequence(item["audio_file"], item["text"], *func_args)
return phonemes
@staticmethod
def compute_lengths(samples):
audio_lengths = []
text_lengths = []
for item in samples:
text, wav_file, *_ = _parse_sample(item)
audio_lengths.append(os.path.getsize(wav_file) / 16 * 8) # assuming 16bit audio
text_lengths.append(len(text))
audio_lengths = np.array(audio_lengths)
text_lengths = np.array(text_lengths)
return audio_lengths, text_lengths
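`getsize(wav_file) / 16 * 8` is simply `bytes / 2`, i.e. the sample count of 16-bit mono PCM with the header ignored. A worked check with hypothetical numbers:

# Hypothetical: 1 second of 22050 Hz, 16-bit mono audio carries ~44100 data bytes.
data_bytes = 44100
approx_samples = data_bytes / 16 * 8  # equivalent to data_bytes / 2
assert approx_samples == 22050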
def compute_input_seq(self, num_workers=0):
"""Compute the input sequences with multi-processing.
Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
if not self.use_phonemes:
if self.verbose:
print(" | > Computing input sequences ...")
for idx, item in enumerate(tqdm.tqdm(self.items)):
sequence = np.asarray(
self.tokenizer.text_to_ids(item["text"]),
dtype=np.int32,
)
self.items[idx][0] = sequence
else:
func_args = [
self.phoneme_cache_path,
self.enable_eos_bos,
self.cleaners,
self.phoneme_language,
self.characters,
self.add_blank,
]
if self.verbose:
print(" | > Computing phonemes ...")
if num_workers == 0:
for idx, item in enumerate(tqdm.tqdm(self.items)):
phonemes = self._phoneme_worker([item, func_args])
self.items[idx][0] = phonemes
else:
with Pool(num_workers) as p:
phonemes = list(
tqdm.tqdm(
p.imap(TTSDataset._phoneme_worker, [[item, func_args] for item in self.items]),
total=len(self.items),
)
)
for idx, p in enumerate(phonemes):
self.items[idx][0] = p
def sort_and_filter_items(self, by_audio_len=False):
r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
range.
Args:
by_audio_len (bool): if True, sort by audio length else by text length.
"""
# compute the target sequence length
if by_audio_len:
lengths = []
for item in self.items:
lengths.append(os.path.getsize(item["audio_file"]) / 16 * 8) # assuming 16bit audio
lengths = np.array(lengths)
else:
lengths = np.array([len(ins["text"]) for ins in self.items])
idxs = np.argsort(lengths)
new_items = []
ignored = []
@staticmethod
def sort_and_filter_by_length(lengths: List[int], min_len: int, max_len: int):
idxs = np.argsort(lengths) # ascending order
ignore_idx = []
keep_idx = []
for i, idx in enumerate(idxs):
length = lengths[idx]
if length < self.min_seq_len or length > self.max_seq_len:
ignored.append(idx)
if length < min_len or length > max_len:
ignore_idx.append(idx)
else:
new_items.append(self.items[idx])
keep_idx.append(idx)
return ignore_idx, keep_idx
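A toy run of the filter on made-up lengths, showing that the returned index lists follow ascending length order:

lengths = [5, 50, 500]
ignore_idx, keep_idx = TTSDataset.sort_and_filter_by_length(lengths, min_len=10, max_len=100)
assert keep_idx == [1]  # only the length-50 sample survives
assert ignore_idx == [0, 2]  # too short, then too long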
@staticmethod
def create_buckets(samples, batch_group_size: int):
for i in range(len(samples) // batch_group_size):
offset = i * batch_group_size
end_offset = offset + batch_group_size
temp_items = samples[offset:end_offset]
random.shuffle(temp_items)
samples[offset:end_offset] = temp_items
return samples
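Bucketing permutes items only inside fixed-size windows, so a length-sorted list stays approximately sorted while batches still vary. A toy sketch:

random.seed(0)  # deterministic for the illustration
bucketed = TTSDataset.create_buckets(list(range(8)), batch_group_size=4)
assert sorted(bucketed[:4]) == [0, 1, 2, 3]  # first window shuffled in place
assert sorted(bucketed[4:]) == [4, 5, 6, 7]  # second window shuffled in place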
def preprocess_samples(self):
r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
range.
"""
# sort items based on the sequence length in ascending order
text_ignore_idx, text_keep_idx = self.sort_and_filter_by_length(self.text_lengths, self.min_text_len, self.max_text_len)
audio_ignore_idx, audio_keep_idx = self.sort_and_filter_by_length(self.audio_lengths, self.min_audio_len, self.max_audio_len)
keep_idx = list(set(audio_keep_idx) & set(text_keep_idx))  # keep only samples within both length ranges
ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx))
samples = []
for idx in keep_idx:
samples.append(self.samples[idx])
if len(samples) == 0:
raise RuntimeError(" [!] No samples left")
# shuffle batch groups
if self.batch_group_size > 0:
for i in range(len(new_items) // self.batch_group_size):
offset = i * self.batch_group_size
end_offset = offset + self.batch_group_size
temp_items = new_items[offset:end_offset]
random.shuffle(temp_items)
new_items[offset:end_offset] = temp_items
self.items = new_items
# create batches with similar length items
# the larger the `batch_group_size`, the higher the length variety in a batch.
samples = self.create_buckets(samples, self.batch_group_size)
# update items to the new sorted items
self.samples = samples
if self.verbose:
print(" | > Max length sequence: {}".format(np.max(lengths)))
print(" | > Min length sequence: {}".format(np.min(lengths)))
print(" | > Avg length sequence: {}".format(np.mean(lengths)))
print(
" | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}".format(
self.max_seq_len, self.min_seq_len, len(ignored)
)
)
print(" | > Preprocessing samples")
print(" | > Max text length: {}".format(np.max(self.text_lengths)))
print(" | > Min text length: {}".format(np.min(self.text_lengths)))
print(" | > Avg text length: {}".format(np.mean(self.text_lengths)))
print(" | ")
print(" | > Max audio length: {}".format(np.max(self.audio_lengths)))
print(" | > Min audio length: {}".format(np.min(self.audio_lengths)))
print(" | > Avg audio length: {}".format(np.mean(self.audio_lengths)))
print(f" | > Num. instances discarded samples: {len(ignore_idx)}")
print(" | > Batch group size: {}.".format(self.batch_group_size))
def __len__(self):
return len(self.items)
def __getitem__(self, idx):
return self.load_data(idx)
@staticmethod
def _sort_batch(batch, text_lengths):
"""Sort the batch by the input text length for RNN efficiency.
@@ -338,10 +342,10 @@ class TTSDataset(Dataset):
# Puts each data field into a tensor with outer dimension batch size
if isinstance(batch[0], collections.abc.Mapping):
text_lengths = np.array([len(d["text"]) for d in batch])
token_ids_lengths = np.array([len(d["token_ids"]) for d in batch])
# sort items with text input length for RNN efficiency
batch, text_lengths, ids_sorted_decreasing = self._sort_batch(batch, text_lengths)
batch, token_ids_lengths, ids_sorted_decreasing = self._sort_batch(batch, token_ids_lengths)
# convert list of dicts to dict of lists
batch = {k: [dic[k] for dic in batch] for k in batch[0]}
@@ -383,7 +387,7 @@ class TTSDataset(Dataset):
stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)
# PAD sequences with longest instance in the batch
text = prepare_data(batch["text"]).astype(np.int32)
text = prepare_data(batch["token_ids"]).astype(np.int32)
# PAD features with longest instance
mel = prepare_tensor(mel, self.outputs_per_step)
@@ -392,12 +396,13 @@ class TTSDataset(Dataset):
mel = mel.transpose(0, 2, 1)
# convert things to pytorch
text_lengths = torch.LongTensor(text_lengths)
token_ids_lengths = torch.LongTensor(token_ids_lengths)
text = torch.LongTensor(text)
mel = torch.FloatTensor(mel).contiguous()
mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets)
# speaker vectors
if d_vectors is not None:
d_vectors = torch.FloatTensor(d_vectors)
@@ -408,14 +413,13 @@ class TTSDataset(Dataset):
language_ids = torch.LongTensor(language_ids)
# compute linear spectrogram
linear = None
if self.compute_linear_spec:
linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
linear = prepare_tensor(linear, self.outputs_per_step)
linear = linear.transpose(0, 2, 1)
assert mel.shape[1] == linear.shape[1]
linear = torch.FloatTensor(linear).contiguous()
else:
linear = None
# format waveforms
wav_padded = None
@@ -431,8 +435,7 @@ class TTSDataset(Dataset):
wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
wav_padded.transpose_(1, 2)
# compute f0
# TODO: compare perf in collate_fn vs in load_data
# format F0
if self.compute_f0:
pitch = prepare_data(batch["pitch"])
assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
@@ -440,7 +443,8 @@ class TTSDataset(Dataset):
else:
pitch = None
# collate attention alignments
# format attention masks
attns = None
if batch["attn"][0] is not None:
attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing]
for idx, attn in enumerate(attns):
@@ -451,12 +455,10 @@ class TTSDataset(Dataset):
attns[idx] = attn
attns = prepare_tensor(attns, self.outputs_per_step)
attns = torch.FloatTensor(attns).unsqueeze(1)
else:
attns = None
# TODO: return dictionary
return {
"text": text,
"text_lengths": text_lengths,
"token_id": text,
"token_id_lengths": token_ids_lengths,
"speaker_names": batch["speaker_name"],
"linear": linear,
"mel": mel,
@@ -482,22 +484,179 @@ class TTSDataset(Dataset):
)
class PitchExtractor:
"""Pitch Extractor for computing F0 from wav files.
class PhonemeDataset(Dataset):
"""Phoneme Dataset for converting input text to phonemes and then token IDs
At initialization, it pre-computes the phonemes under `cache_path` and loads them in training to reduce data
loading latency. If `cache_path` is already present, it skips the pre-computation.
Args:
items (List[List]): Dataset samples.
verbose (bool): Whether to print the progress.
samples (Union[List[List], List[Dict]]):
List of samples. Each sample is a list or a dict.
tokenizer (TTSTokenizer):
Tokenizer to convert input text to phonemes.
cache_path (str):
Path to cache phonemes. If `cache_path` is already present or None, it skips the pre-computation.
precompute_num_workers (int):
Number of workers used for pre-computing the phonemes. Defaults to 0.
"""
def __init__(
self,
items: List[Dict],
verbose=False,
samples: Union[List[Dict], List[List]],
tokenizer: "TTSTokenizer",
cache_path: str,
precompute_num_workers=0,
):
self.items = items
self.samples = samples
self.tokenizer = tokenizer
self.cache_path = cache_path
if cache_path is not None and not os.path.exists(cache_path):
os.makedirs(cache_path)
self.precompute(precompute_num_workers)
def __getitem__(self, index):
text, wav_file, *_ = _parse_sample(self.samples[index])
ids = self.compute_or_load(wav_file, text)
ph_hat = self.tokenizer.ids_to_text(ids)
return {"text": text, "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}
def __len__(self):
return len(self.samples)
def compute_or_load(self, wav_file, text):
"""Compute phonemes for the given text.
If the phonemes are already cached, load them from cache.
"""
file_name = os.path.splitext(os.path.basename(wav_file))[0]
file_ext = "_phoneme.npy"
cache_path = os.path.join(self.cache_path, file_name + file_ext)
try:
ids = np.load(cache_path)
except FileNotFoundError:
ids = self.tokenizer.text_to_ids(text)
np.save(cache_path, ids)
return ids
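The cache-or-compute idiom in isolation; `_compute_or_load` and the lambda payloads below are hypothetical stand-ins for the tokenizer call:

import tempfile

def _compute_or_load(cache_path, compute):
    try:
        return np.load(cache_path)
    except FileNotFoundError:
        data = compute()
        np.save(cache_path, data)
        return data

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "sample_phoneme.npy")
    first = _compute_or_load(path, lambda: np.array([1, 2, 3]))  # computed and cached
    second = _compute_or_load(path, lambda: np.array([9, 9, 9]))  # served from cache
    assert (first == second).all()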
def get_pad_id(self):
"""Get pad token ID for sequence padding"""
return self.tokenizer.pad_id
def precompute(self, num_workers=1):
"""Precompute phonemes for all samples.
We use pytorch dataloader because we are lazy.
"""
with tqdm.tqdm(total=len(self)) as pbar:
batch_size = num_workers if num_workers > 0 else 1
dataloader = torch.utils.data.DataLoader(
batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
)
for _ in dataloader:
pbar.update(batch_size)
def collate_fn(self, batch):
ids = [item["token_ids"] for item in batch]
ids_lens = [item["token_ids_len"] for item in batch]
texts = [item["text"] for item in batch]
texts_hat = [item["ph_hat"] for item in batch]
ids_lens_max = max(ids_lens)
ids_torch = torch.LongTensor(len(ids), ids_lens_max).fill_(self.get_pad_id())
for i, ids_len in enumerate(ids_lens):
ids_torch[i, :ids_len] = torch.LongTensor(ids[i])
return {"text": texts, "ph_hat": texts_hat, "token_ids": ids_torch}
def print_logs(self, level: int = 0) -> None:
indent = "\t" * level
print("\n")
print(f"{indent}> PhonemeDataset ")
print(f"{indent}| > Tokenizer:")
self.tokenizer.print_logs(level + 1)
print(f"{indent}| > Number of instances : {len(self.samples)}")
class F0Dataset:
"""F0 Dataset for computing F0 from wav files in CPU
Pre-compute F0 values for all the samples at initialization if `cache_path` is not None or already present. It
also computes the mean and std of F0 values if `normalize_f0` is True.
Args:
samples (Union[List[List], List[Dict]]):
List of samples. Each sample is a list or a dict.
ap (AudioProcessor):
AudioProcessor to compute F0 from wav files.
cache_path (str):
Path to cache F0 values. If `cache_path` is already present or None, it skips the pre-computation.
Defaults to None.
precompute_num_workers (int):
Number of workers used for pre-computing the F0 values. Defaults to 0.
normalize_f0 (bool):
Whether to normalize F0 values by mean and std. Defaults to True.
"""
def __init__(
self,
samples: Union[List[List], List[Dict]],
ap: "AudioProcessor",
verbose=False,
cache_path: str = None,
precompute_num_workers=0,
normalize_f0=True,
):
self.samples = samples
self.ap = ap
self.verbose = verbose
self.cache_path = cache_path
self.normalize_f0 = normalize_f0
self.pad_id = 0.0
self.mean = None
self.std = None
if cache_path is not None and not os.path.exists(cache_path):
os.makedirs(cache_path)
self.precompute(precompute_num_workers)
if normalize_f0:
self.load_stats(cache_path)
def __getitem__(self, idx):
_, wav_file, *_ = _parse_sample(self.samples[idx])
f0 = self.compute_or_load(wav_file)
if self.normalize_f0:
assert self.mean is not None and self.std is not None, " [!] Mean and STD are not available"
f0 = self.normalize(f0)
return {"audio_file": wav_file, "f0": f0}
def __len__(self):
return len(self.samples)
def precompute(self, num_workers=0):
with tqdm.tqdm(total=len(self)) as pbar:
batch_size = num_workers if num_workers > 0 else 1
dataloader = torch.utils.data.DataLoader(
batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
)
computed_data = []
for batch in dataloader:
f0 = batch["f0"]
computed_data.append([f for f in f0])
pbar.update(batch_size)
if self.normalize_f0:
computed_data = [tensor for batch in computed_data for tensor in batch] # flatten
pitch_mean, pitch_std = self.compute_pitch_stats(computed_data)
pitch_stats = {"mean": pitch_mean, "std": pitch_std}
np.save(os.path.join(self.cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)
def get_pad_id(self):
return self.pad_id
@staticmethod
def create_pitch_file_path(wav_file, cache_path):
@@ -519,69 +678,128 @@ class PitchExtractor:
mean, std = np.mean(nonzeros), np.std(nonzeros)
return mean, std
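A toy check of the statistics, assuming the elided head of `compute_pitch_stats` gathers the nonzero (voiced) frames of all vectors into `nonzeros`:

vecs = [np.array([0.0, 100.0]), np.array([140.0, 0.0])]
mean, std = F0Dataset.compute_pitch_stats(vecs)
assert mean == 120.0 and std == 20.0  # zeros are excluded from both statistics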
def normalize_pitch(self, pitch):
def load_stats(self, cache_path):
stats_path = os.path.join(cache_path, "pitch_stats.npy")
stats = np.load(stats_path, allow_pickle=True).item()
self.mean = stats["mean"].astype(np.float32)
self.std = stats["std"].astype(np.float32)
def normalize(self, pitch):
zero_idxs = np.where(pitch == 0.0)[0]
pitch = pitch - self.mean
pitch = pitch / self.std
pitch[zero_idxs] = 0.0
return pitch
def denormalize_pitch(self, pitch):
def denormalize(self, pitch):
zero_idxs = np.where(pitch == 0.0)[0]
pitch *= self.std
pitch += self.mean
pitch[zero_idxs] = 0.0
return pitch
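Zero frames mark unvoiced regions, so both transforms pin them back to zero; a round-trip sketch with hypothetical statistics, bypassing __init__ for brevity:

f0_ds = F0Dataset.__new__(F0Dataset)  # skip __init__ for the sketch
f0_ds.mean, f0_ds.std = np.float32(120.0), np.float32(20.0)
pitch = np.array([0.0, 100.0, 140.0], dtype=np.float32)
restored = f0_ds.denormalize(f0_ds.normalize(pitch.copy()))
assert np.allclose(restored, pitch) and restored[0] == 0.0  # zeros survive the round trip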
@staticmethod
def load_or_compute_pitch(ap, wav_file, cache_path):
def compute_or_load(self, wav_file):
"""
Compute pitch and return a numpy array of pitch values.
"""
pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path)
pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
if not os.path.exists(pitch_file):
pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
else:
pitch = np.load(pitch_file)
return pitch.astype(np.float32)
@staticmethod
def _pitch_worker(args):
item = args[0]
ap = args[1]
cache_path = args[2]
pitch_file = PitchExtractor.create_pitch_file_path(item["audio_file"], cache_path)
if not os.path.exists(pitch_file):
pitch = PitchExtractor._compute_and_save_pitch(ap, item["audio_file"], pitch_file)
return pitch
return None
def collate_fn(self, batch):
audio_file = [item["audio_file"] for item in batch]
f0s = [item["f0"] for item in batch]
f0_lens = [len(item["f0"]) for item in batch]
f0_lens_max = max(f0_lens)
f0s_torch = torch.FloatTensor(len(f0s), f0_lens_max).fill_(self.get_pad_id())
for i, f0_len in enumerate(f0_lens):
f0s_torch[i, :f0_len] = torch.FloatTensor(f0s[i])  # F0 values are floats; LongTensor would truncate them
return {"audio_file": audio_file, "f0": f0s_torch, "f0_lens": f0_lens}
def compute_pitch(self, ap, cache_path, num_workers=0):
"""Compute the input sequences with multi-processing.
Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
if not os.path.exists(cache_path):
os.makedirs(cache_path, exist_ok=True)
def print_logs(self, level: int = 0) -> None:
indent = "\t" * level
print("\n")
print(f"{indent}> F0Dataset ")
print(f"{indent}| > Number of instances : {len(self.samples)}")
if self.verbose:
print(" | > Computing pitch features ...")
if num_workers == 0:
pitch_vecs = []
for _, item in enumerate(tqdm.tqdm(self.items)):
pitch_vecs += [self._pitch_worker([item, ap, cache_path])]
else:
with Pool(num_workers) as p:
pitch_vecs = list(
tqdm.tqdm(
p.imap(PitchExtractor._pitch_worker, [[item, ap, cache_path] for item in self.items]),
total=len(self.items),
)
)
pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs)
pitch_stats = {"mean": pitch_mean, "std": pitch_std}
np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)
def load_pitch_stats(self, cache_path):
stats_path = os.path.join(cache_path, "pitch_stats.npy")
stats = np.load(stats_path, allow_pickle=True).item()
self.mean = stats["mean"].astype(np.float32)
self.std = stats["std"].astype(np.float32)
if __name__ == "__main__":
from torch.utils.data import DataLoader
from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.text.characters import IPAPhonemes
from TTS.tts.utils.text.phonemizers import ESpeak
dataset_config = BaseDatasetConfig(
name="ljspeech",
meta_file_train="metadata.csv",
path="/Users/erengolge/Projects/TTS/recipes/ljspeech/LJSpeech-1.1",
)
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
samples = train_samples + eval_samples
phonemizer = ESpeak(language="en-us")
tokenizer = TTSTokenizer(use_phonemes=True, characters=IPAPhonemes(), phonemizer=phonemizer)
# ph_dataset = PhonemeDataset(samples, tokenizer, phoneme_cache_path="/Users/erengolge/Projects/TTS/phonemes_tests")
# ph_dataset.precompute(num_workers=4)
# dataloader = DataLoader(ph_dataset, batch_size=4, shuffle=False, num_workers=4, collate_fn=ph_dataset.collate_fn)
# for batch in dataloader:
# print(batch)
# break
audio_config = BaseAudioConfig(
sample_rate=22050,
win_length=1024,
hop_length=256,
num_mels=80,
preemphasis=0.0,
ref_level_db=20,
log_func="np.log",
do_trim_silence=True,
trim_db=45,
mel_fmin=0,
mel_fmax=8000,
spec_gain=1.0,
signal_norm=False,
do_amp_to_db_linear=False,
)
ap = AudioProcessor.init_from_config(audio_config)
# f0_dataset = F0Dataset(samples, ap, cache_path="/Users/erengolge/Projects/TTS/f0_tests", verbose=False, precompute_num_workers=4)
# dataloader = DataLoader(f0_dataset, batch_size=4, shuffle=False, num_workers=4, collate_fn=f0_dataset.collate_fn)
# for batch in dataloader:
# print(batch)
# breakpoint()
# break
dataset = TTSDataset(
outputs_per_step=1,
compute_linear_spec=False,
meta_data=samples,
ap=ap,
return_wav=False,
batch_group_size=0,
min_text_len=0,
max_text_len=500,
use_noise_augment=False,
verbose=True,
speaker_id_mapping=None,
d_vector_mapping=None,
compute_f0=True,
f0_cache_path="/Users/erengolge/Projects/TTS/f0_tests",
tokenizer=tokenizer,
phoneme_cache_path="/Users/erengolge/Projects/TTS/phonemes_tests",
precompute_num_workers=4,
)
dataloader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0, collate_fn=dataset.collate_fn)
for batch in dataloader:
print(batch)
break