From 176b712c1a40cf630da9a77f1826836723c40fde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 30 Nov 2021 15:50:18 +0100 Subject: [PATCH] =?UTF-8?q?Refactor=20TTSDataset=20=E2=9A=A1=EF=B8=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- TTS/tts/datasets/dataset.py | 684 ++++++++++++++++++++++++------------ 1 file changed, 451 insertions(+), 233 deletions(-) diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 8c21d7d0..60b514c2 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -2,7 +2,7 @@ import collections import os import random from multiprocessing import Pool -from typing import Dict, List +from typing import Dict, List, Union import numpy as np import torch @@ -14,6 +14,24 @@ from TTS.tts.utils.text import TTSTokenizer from TTS.utils.audio import AudioProcessor +def _parse_sample(item): + language_name = None + attn_file = None + if len(item) == 5: + text, wav_file, speaker_name, language_name, attn_file = item + elif len(item) == 4: + text, wav_file, speaker_name, language_name = item + elif len(item) == 3: + text, wav_file, speaker_name = item + else: + raise ValueError(" [!] Dataset cannot parse the sample.") + return text, wav_file, speaker_name, language_name, attn_file + + +def noise_augment_audio(wav): + return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape) + + class TTSDataset(Dataset): def __init__( self, @@ -26,9 +44,12 @@ class TTSDataset(Dataset): f0_cache_path: str = None, return_wav: bool = False, batch_group_size: int = 0, - min_seq_len: int = 0, - max_seq_len: int = float("inf"), + min_text_len: int = 0, + max_text_len: int = float("inf"), + min_audio_len: int = 0, + max_audio_len: int = float("inf"), phoneme_cache_path: str = None, + precompute_num_workers: int = 0, speaker_id_mapping: Dict = None, d_vector_mapping: Dict = None, language_id_mapping: Dict = None, @@ -37,7 +58,7 @@ class TTSDataset(Dataset): ): """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs. - If you need something different, you can inherit and override. + If you need something different, you can subclass and override. Args: outputs_per_step (int): Number of time frames predicted per step. @@ -61,17 +82,24 @@ class TTSDataset(Dataset): sequences by length. It shuffles each batch with bucketing to gather similar lenght sequences in a batch. Set 0 to disable. Defaults to 0. - min_seq_len (int): Minimum input sequence length to be processed - by sort_inputs`. Filter out input sequences that are shorter than this. Some models have a - minimum input length due to its architecture. Defaults to 0. + min_text_len (int): Minimum length of input text to be used. All shorter samples will be ignored. + Defaults to 0. - max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this. - It helps for controlling the VRAM usage against long input sequences. Especially models with - RNN layers are sensitive to input length. Defaults to `Inf`. + max_text_len (int): Maximum length of input text to be used. All longer samples will be ignored. + Defaults to float("inf"). + + min_audio_len (int): Minimum length of input audio to be used. All shorter samples will be ignored. + Defaults to 0. + + max_audio_len (int): Maximum length of input audio to be used. All longer samples will be ignored. + The maximum length in the dataset defines the VRAM used in the training. 
Hence, pay attention to + this value if you encounter an OOM error in training. Defaults to float("inf"). phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a separate file. Defaults to None. + precompute_num_workers (int): Number of workers to precompute features. Defaults to 0. + speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the embedding layer. Defaults to None. @@ -83,15 +111,17 @@ class TTSDataset(Dataset): """ super().__init__() self.batch_group_size = batch_group_size - self.items = meta_data + self._samples = meta_data self.outputs_per_step = outputs_per_step self.sample_rate = ap.sample_rate self.compute_linear_spec = compute_linear_spec self.return_wav = return_wav self.compute_f0 = compute_f0 self.f0_cache_path = f0_cache_path - self.min_seq_len = min_seq_len - self.max_seq_len = max_seq_len + self.min_audio_len = min_audio_len + self.max_audio_len = max_audio_len + self.min_text_len = min_text_len + self.max_text_len = max_text_len self.ap = ap self.phoneme_cache_path = phoneme_cache_path self.speaker_id_mapping = speaker_id_mapping @@ -100,112 +130,113 @@ class TTSDataset(Dataset): self.use_noise_augment = use_noise_augment self.verbose = verbose - self.input_seq_computed = False self.rescue_item_idx = 1 self.pitch_computed = False self.tokenizer = tokenizer - if self.tokenizer.use_phonemes and not os.path.isdir(phoneme_cache_path): - os.makedirs(phoneme_cache_path, exist_ok=True) + self.audio_lengths, self.text_lengths = self.compute_lengths(self.samples) + + if self.tokenizer.use_phonemes: + self.phoneme_dataset = PhonemeDataset( + self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers + ) + if compute_f0: - self.pitch_extractor = PitchExtractor(self.items, verbose=verbose) + self.f0_dataset = F0Dataset( + self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers + ) if self.verbose: self.print_logs() + @property + def samples(self): + return self._samples + + @samples.setter + def samples(self, new_samples): + self._samples = new_samples + if hasattr(self, "f0_dataset"): + self.f0_dataset.samples = new_samples + if hasattr(self, "phoneme_dataset"): + self.phoneme_dataset.samples = new_samples + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + return self.load_data(idx) + def print_logs(self, level: int = 0) -> None: indent = "\t" * level print("\n") print(f"{indent}> DataLoader initialization") print(f"{indent}| > Tokenizer:") self.tokenizer.print_logs(level + 1) - print(f"{indent}| > Number of instances : {len(self.items)}") + print(f"{indent}| > Number of instances : {len(self.samples)}") def load_wav(self, filename): - audio = self.ap.load_wav(filename) - return audio + waveform = self.ap.load_wav(filename) + assert waveform.size > 0 + return waveform - @staticmethod - def load_np(filename): - data = np.load(filename).astype("float32") - return data + def get_phonemes(self, idx, text): + out_dict = self.phoneme_dataset[idx] + assert text == out_dict["text"], f"{text} != {out_dict['text']}" + assert out_dict["token_ids"].size > 0 + return out_dict - @staticmethod - def _generate_and_cache_phoneme_sequence(text, tokenizer, cache_path): - """generate a phoneme sequence from text. - since the usage is for subsequent caching, we never add bos and - eos chars here. 
Instead we add those dynamically later; based on the
-        config option."""
-        phonemes = tokenizer.text_to_ids(text)
-        phonemes = np.asarray(phonemes, dtype=np.int32)
-        np.save(cache_path, phonemes)
-        return phonemes
+    def get_f0(self, idx):
+        out_dict = self.f0_dataset[idx]
+        _, wav_file, *_ = _parse_sample(self.samples[idx])
+        assert wav_file == out_dict["audio_file"]
+        return out_dict

-    @staticmethod
-    def _load_or_generate_phoneme_sequence(wav_file, text, language, tokenizer, phoneme_cache_path):
-        file_name = os.path.splitext(os.path.basename(wav_file))[0]
+    def get_attn_mask(self, attn_file):
+        return np.load(attn_file)

-        # different names for normal phonemes and with blank chars.
-        file_name_ext = "_phoneme.npy"
-        cache_path = os.path.join(phoneme_cache_path, file_name + file_name_ext)
-        try:
-            phonemes = np.load(cache_path)
-        except FileNotFoundError:
-            phonemes = TTSDataset._generate_and_cache_phoneme_sequence(text, tokenizer, cache_path)
-        except (ValueError, IOError):
-            print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
-            phonemes = TTSDataset._generate_and_cache_phoneme_sequence(text, tokenizer, cache_path)
-        phonemes = np.asarray(phonemes, dtype=np.int32)
-        return phonemes
+    def get_token_ids(self, idx, text):
+        if self.tokenizer.use_phonemes:
+            token_ids = self.get_phonemes(idx, text)["token_ids"]
+        else:
+            token_ids = self.tokenizer.text_to_ids(text)
+        return token_ids

     def load_data(self, idx):
-        item = self.items[idx]
+        item = self.samples[idx]
+        raw_text = item["text"]

         wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32)

         # apply noise for augmentation
         if self.use_noise_augment:
-            wav = wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)
+            wav = noise_augment_audio(wav)

-        if not self.input_seq_computed:
-            if self.tokenizer.use_phonemes:
-                text = self._load_or_generate_phoneme_sequence(
-                    item["audio_file"],
-                    item["text"],
-                    item["language"] if item["language"] else self.phoneme_language,
-                    self.tokenizer,
-                    self.phoneme_cache_path,
-                )
-            else:
-                text = np.asarray(
-                    self.tokenizer.text_to_ids(item["text"], item["language"]),
-                    dtype=np.int32,
-                )
-
-        assert text.size > 0, self.items[idx]["audio_file"]
-        assert wav.size > 0, self.items[idx]["audio_file"]
+        # get token ids
+        token_ids = self.get_token_ids(idx, item["text"])

+        # get pre-computed attention maps
         attn = None
         if "alignment_file" in item:
-            attn = np.load(item["alignment_file"])
+            attn = self.get_attn_mask(item["alignment_file"])

-        if len(text) > self.max_seq_len:
-            # return a different sample if the phonemized
-            # text is longer than the threshold
-            # TODO: find a better fix
+        # after phonemization the text length may change;
+        # this is a shameful 🤭 hack to skip samples whose phonemized text exceeds `max_text_len`
+        # TODO: find a better fix
+        if len(token_ids) > self.max_text_len:
             return self.load_data(self.rescue_item_idx)

-        pitch = None
+        # get f0 values
+        f0 = None
         if self.compute_f0:
-            pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, item["audio_file"], self.f0_cache_path)
-            pitch = self.pitch_extractor.normalize_pitch(pitch.astype(np.float32))
+            f0 = self.get_f0(idx)["f0"]

         sample = {
             "raw_text": raw_text,
-            "text": text,
+            "token_ids": token_ids,
             "wav": wav,
-            "pitch": pitch,
+            "pitch": f0,
             "attn": attn,
             "item_idx": item["audio_file"],
             "speaker_name": item["speaker_name"],
@@ -215,105 +246,78 @@ class TTSDataset(Dataset):
         return sample

     @staticmethod
-    def _phoneme_worker(args):
-        item = args[0]
-        func_args = args[1]
-        func_args[3] = (
-            item["language"] if "language" in item and item["language"] else func_args[3]
-        )  # override phoneme language if specified by the dataset formatter
-        phonemes = TTSDataset._load_or_generate_phoneme_sequence(item["audio_file"], item["text"], *func_args)
-        return phonemes
+    def compute_lengths(samples):
+        audio_lengths = []
+        text_lengths = []
+        for item in samples:
+            text, wav_file, *_ = _parse_sample(item)
+            audio_lengths.append(os.path.getsize(wav_file) / 16 * 8)  # assuming 16bit audio
+            text_lengths.append(len(text))
+        audio_lengths = np.array(audio_lengths)
+        text_lengths = np.array(text_lengths)
+        return audio_lengths, text_lengths

-    def compute_input_seq(self, num_workers=0):
-        """Compute the input sequences with multi-processing.
-        Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
-        if not self.use_phonemes:
-            if self.verbose:
-                print(" | > Computing input sequences ...")
-            for idx, item in enumerate(tqdm.tqdm(self.items)):
-                sequence = np.asarray(
-                    self.tokenizer.text_to_ids(item["text"]),
-                    dtype=np.int32,
-                )
-                self.items[idx][0] = sequence
-        else:
-            func_args = [
-                self.phoneme_cache_path,
-                self.enable_eos_bos,
-                self.cleaners,
-                self.phoneme_language,
-                self.characters,
-                self.add_blank,
-            ]
-            if self.verbose:
-                print(" | > Computing phonemes ...")
-            if num_workers == 0:
-                for idx, item in enumerate(tqdm.tqdm(self.items)):
-                    phonemes = self._phoneme_worker([item, func_args])
-                    self.items[idx][0] = phonemes
-            else:
-                with Pool(num_workers) as p:
-                    phonemes = list(
-                        tqdm.tqdm(
-                            p.imap(TTSDataset._phoneme_worker, [[item, func_args] for item in self.items]),
-                            total=len(self.items),
-                        )
-                    )
-                    for idx, p in enumerate(phonemes):
-                        self.items[idx][0] = p
-
-    def sort_and_filter_items(self, by_audio_len=False):
-        r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
-        range.
-
-        Args:
-            by_audio_len (bool): if True, sort by audio length else by text length.
-        """
-        # compute the target sequence length
-        if by_audio_len:
-            lengths = []
-            for item in self.items:
-                lengths.append(os.path.getsize(item["audio_file"]) / 16 * 8)  # assuming 16bit audio
-            lengths = np.array(lengths)
-        else:
-            lengths = np.array([len(ins["text"]) for ins in self.items])
-
-        idxs = np.argsort(lengths)
-        new_items = []
-        ignored = []
+    @staticmethod
+    def sort_and_filter_by_length(lengths: List[int], min_len: int, max_len: int):
+        idxs = np.argsort(lengths)  # ascending order
+        ignore_idx = []
+        keep_idx = []
         for i, idx in enumerate(idxs):
             length = lengths[idx]
-            if length < self.min_seq_len or length > self.max_seq_len:
-                ignored.append(idx)
+            if length < min_len or length > max_len:
+                ignore_idx.append(idx)
             else:
-                new_items.append(self.items[idx])
+                keep_idx.append(idx)
+        return ignore_idx, keep_idx
+
+    @staticmethod
+    def create_buckets(samples, batch_group_size: int):
+        for i in range(len(samples) // batch_group_size):
+            offset = i * batch_group_size
+            end_offset = offset + batch_group_size
+            temp_items = samples[offset:end_offset]
+            random.shuffle(temp_items)
+            samples[offset:end_offset] = temp_items
+        return samples
+
+    def preprocess_samples(self):
+        r"""Sort `samples` by length in ascending order and filter out samples that fall outside the text or
+        audio length limits.
+        """

+        # sort samples by the sequence length in ascending order
+        text_ignore_idx, text_keep_idx = self.sort_and_filter_by_length(self.text_lengths, self.min_text_len, self.max_text_len)
+        audio_ignore_idx, audio_keep_idx = self.sort_and_filter_by_length(self.audio_lengths, self.min_audio_len, self.max_audio_len)
+        # keep a sample only if it passes both the text and the audio filters
+        keep_idx = list(set(audio_keep_idx) & set(text_keep_idx))
+        ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx))
+
+        samples = []
+        for idx in keep_idx:
+            samples.append(self.samples[idx])
+
+        if len(samples) == 0:
+            raise RuntimeError(" [!] No samples left after filtering by length.")
+
         # shuffle batch groups
-        if self.batch_group_size > 0:
-            for i in range(len(new_items) // self.batch_group_size):
-                offset = i * self.batch_group_size
-                end_offset = offset + self.batch_group_size
-                temp_items = new_items[offset:end_offset]
-                random.shuffle(temp_items)
-                new_items[offset:end_offset] = temp_items
-        self.items = new_items
+        # create batches with similar length items
+        # the larger the `batch_group_size`, the higher the length variety in a batch.
+        if self.batch_group_size > 0:
+            samples = self.create_buckets(samples, self.batch_group_size)
+
+        # update the dataset with the filtered and sorted samples
+        self.samples = samples

         if self.verbose:
-            print(" | > Max length sequence: {}".format(np.max(lengths)))
-            print(" | > Min length sequence: {}".format(np.min(lengths)))
-            print(" | > Avg length sequence: {}".format(np.mean(lengths)))
-            print(
-                " | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}".format(
-                    self.max_seq_len, self.min_seq_len, len(ignored)
-                )
-            )
+            print(" | > Preprocessing samples")
+            print(" | > Max text length: {}".format(np.max(self.text_lengths)))
+            print(" | > Min text length: {}".format(np.min(self.text_lengths)))
+            print(" | > Avg text length: {}".format(np.mean(self.text_lengths)))
+            print(" | ")
+            print(" | > Max audio length: {}".format(np.max(self.audio_lengths)))
+            print(" | > Min audio length: {}".format(np.min(self.audio_lengths)))
+            print(" | > Avg audio length: {}".format(np.mean(self.audio_lengths)))
+            print(f" | > Num. instances discarded by length limits: {len(ignore_idx)}")
            print(" | > Batch group size: {}.".format(self.batch_group_size))

-    def __len__(self):
-        return len(self.items)
-
-    def __getitem__(self, idx):
-        return self.load_data(idx)
-
     @staticmethod
     def _sort_batch(batch, text_lengths):
        """Sort the batch by the input text length for RNN efficiency.
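
To illustrate the two new static helpers above, here is a minimal sketch, assuming the patch is applied and the module is importable as `TTS.tts.datasets.dataset`; the lengths and the bucket size are made-up values, not defaults from the patch:

    import random

    import numpy as np

    from TTS.tts.datasets.dataset import TTSDataset

    # per-sample text lengths; indices 0 and 3 fall outside [10, 500]
    lengths = np.array([5, 120, 40, 900, 60, 75])
    ignore_idx, keep_idx = TTSDataset.sort_and_filter_by_length(lengths, min_len=10, max_len=500)
    # keep_idx -> [2, 4, 5, 1] (ascending by length), ignore_idx -> [0, 3]

    # bucketing: each consecutive group of `batch_group_size` items is shuffled
    # in place, so batches still see similar lengths while the order inside a
    # bucket changes between runs
    kept = [f"sample-{i}" for i in keep_idx]
    kept = TTSDataset.create_buckets(kept, batch_group_size=2)
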
@@ -338,10 +342,10 @@ class TTSDataset(Dataset): # Puts each data field into a tensor with outer dimension batch size if isinstance(batch[0], collections.abc.Mapping): - text_lengths = np.array([len(d["text"]) for d in batch]) + token_ids_lengths = np.array([len(d["token_ids"]) for d in batch]) # sort items with text input length for RNN efficiency - batch, text_lengths, ids_sorted_decreasing = self._sort_batch(batch, text_lengths) + batch, token_ids_lengths, ids_sorted_decreasing = self._sort_batch(batch, token_ids_lengths) # convert list of dicts to dict of lists batch = {k: [dic[k] for dic in batch] for k in batch[0]} @@ -383,7 +387,7 @@ class TTSDataset(Dataset): stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step) # PAD sequences with longest instance in the batch - text = prepare_data(batch["text"]).astype(np.int32) + text = prepare_data(batch["token_ids"]).astype(np.int32) # PAD features with longest instance mel = prepare_tensor(mel, self.outputs_per_step) @@ -392,12 +396,13 @@ class TTSDataset(Dataset): mel = mel.transpose(0, 2, 1) # convert things to pytorch - text_lengths = torch.LongTensor(text_lengths) + token_ids_lengths = torch.LongTensor(token_ids_lengths) text = torch.LongTensor(text) mel = torch.FloatTensor(mel).contiguous() mel_lengths = torch.LongTensor(mel_lengths) stop_targets = torch.FloatTensor(stop_targets) + # speaker vectors if d_vectors is not None: d_vectors = torch.FloatTensor(d_vectors) @@ -408,14 +413,13 @@ class TTSDataset(Dataset): language_ids = torch.LongTensor(language_ids) # compute linear spectrogram + linear = None if self.compute_linear_spec: linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]] linear = prepare_tensor(linear, self.outputs_per_step) linear = linear.transpose(0, 2, 1) assert mel.shape[1] == linear.shape[1] linear = torch.FloatTensor(linear).contiguous() - else: - linear = None # format waveforms wav_padded = None @@ -431,8 +435,7 @@ class TTSDataset(Dataset): wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w) wav_padded.transpose_(1, 2) - # compute f0 - # TODO: compare perf in collate_fn vs in load_data + # format F0 if self.compute_f0: pitch = prepare_data(batch["pitch"]) assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}" @@ -440,7 +443,8 @@ class TTSDataset(Dataset): else: pitch = None - # collate attention alignments + # format attention masks + attns = None if batch["attn"][0] is not None: attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing] for idx, attn in enumerate(attns): @@ -451,12 +455,10 @@ class TTSDataset(Dataset): attns[idx] = attn attns = prepare_tensor(attns, self.outputs_per_step) attns = torch.FloatTensor(attns).unsqueeze(1) - else: - attns = None - # TODO: return dictionary + return { - "text": text, - "text_lengths": text_lengths, + "token_id": text, + "token_id_lengths": token_ids_lengths, "speaker_names": batch["speaker_name"], "linear": linear, "mel": mel, @@ -482,22 +484,179 @@ class TTSDataset(Dataset): ) -class PitchExtractor: - """Pitch Extractor for computing F0 from wav files. +class PhonemeDataset(Dataset): + """Phoneme Dataset for converting input text to phonemes and then token IDs + + At initialization, it pre-computes the phonemes under `cache_path` and loads them in training to reduce data + loading latency. If `cache_path` is already present, it skips the pre-computation. + Args: - items (List[List]): Dataset samples. - verbose (bool): Whether to print the progress. 
+
+        samples (Union[List[List], List[Dict]]):
+            List of samples. Each sample is a list or a dict.
+
+        tokenizer (TTSTokenizer):
+            Tokenizer to convert input text to phonemes.
+
+        cache_path (str):
+            Path to cache phonemes. If `cache_path` is already present or None, it skips the pre-computation.
+
+        precompute_num_workers (int):
+            Number of workers used for pre-computing the phonemes. Defaults to 0.
    """

     def __init__(
         self,
-        items: List[Dict],
-        verbose=False,
+        samples: Union[List[Dict], List[List]],
+        tokenizer: "TTSTokenizer",
+        cache_path: str,
+        precompute_num_workers=0,
     ):
-        self.items = items
+        self.samples = samples
+        self.tokenizer = tokenizer
+        self.cache_path = cache_path
+        if cache_path is not None and not os.path.exists(cache_path):
+            os.makedirs(cache_path)
+            self.precompute(precompute_num_workers)
+
+    def __getitem__(self, index):
+        text, wav_file, *_ = _parse_sample(self.samples[index])
+        ids = self.compute_or_load(wav_file, text)
+        ph_hat = self.tokenizer.ids_to_text(ids)
+        return {"text": text, "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}
+
+    def __len__(self):
+        return len(self.samples)
+
+    def compute_or_load(self, wav_file, text):
+        """Compute phonemes for the given text.
+
+        If the phonemes are already cached, load them from cache.
+        """
+        file_name = os.path.splitext(os.path.basename(wav_file))[0]
+        file_ext = "_phoneme.npy"
+        cache_path = os.path.join(self.cache_path, file_name + file_ext)
+        try:
+            ids = np.load(cache_path)
+        except FileNotFoundError:
+            ids = self.tokenizer.text_to_ids(text)
+            np.save(cache_path, ids)
+        return ids
+
+    def get_pad_id(self):
+        """Get the pad token ID used for sequence padding."""
+        return self.tokenizer.pad_id
+
+    def precompute(self, num_workers=1):
+        """Precompute phonemes for all samples.
+
+        We use a PyTorch DataLoader so the work is spread over `num_workers` processes.
+        """
+        with tqdm.tqdm(total=len(self)) as pbar:
+            batch_size = num_workers if num_workers > 0 else 1
+            dataloader = torch.utils.data.DataLoader(
+                batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
+            )
+            for _ in dataloader:
+                pbar.update(batch_size)
+
+    def collate_fn(self, batch):
+        ids = [item["token_ids"] for item in batch]
+        ids_lens = [item["token_ids_len"] for item in batch]
+        texts = [item["text"] for item in batch]
+        texts_hat = [item["ph_hat"] for item in batch]
+        ids_lens_max = max(ids_lens)
+        ids_torch = torch.LongTensor(len(ids), ids_lens_max).fill_(self.get_pad_id())
+        for i, ids_len in enumerate(ids_lens):
+            ids_torch[i, :ids_len] = torch.LongTensor(ids[i])
+        return {"text": texts, "ph_hat": texts_hat, "token_ids": ids_torch}
+
+    def print_logs(self, level: int = 0) -> None:
+        indent = "\t" * level
+        print("\n")
+        print(f"{indent}> PhonemeDataset ")
+        print(f"{indent}| > Tokenizer:")
+        self.tokenizer.print_logs(level + 1)
+        print(f"{indent}| > Number of instances : {len(self.samples)}")
+
+
+class F0Dataset:
+    """F0 Dataset for computing F0 from wav files on the CPU.
+
+    Pre-computes F0 values for all the samples at initialization if `cache_path` is set and not already present. It
+    also computes the mean and std of the F0 values if `normalize_f0` is True.
+
+    Args:
+        samples (Union[List[List], List[Dict]]):
+            List of samples. Each sample is a list or a dict.
+
+        ap (AudioProcessor):
+            AudioProcessor to compute F0 from wav files.
+
+        cache_path (str):
+            Path to cache F0 values. If `cache_path` is already present or None, it skips the pre-computation.
+            Defaults to None.
+
+        precompute_num_workers (int):
+            Number of workers used for pre-computing the F0 values. Defaults to 0.
+
+        normalize_f0 (bool):
+            Whether to normalize F0 values by mean and std. Defaults to True.
+    """
+
+    def __init__(
+        self,
+        samples: Union[List[List], List[Dict]],
+        ap: "AudioProcessor",
+        verbose=False,
+        cache_path: str = None,
+        precompute_num_workers=0,
+        normalize_f0=True,
+    ):
+        self.samples = samples
+        self.ap = ap
         self.verbose = verbose
+        self.cache_path = cache_path
+        self.normalize_f0 = normalize_f0
+        self.pad_id = 0.0
         self.mean = None
         self.std = None
+        if cache_path is not None and not os.path.exists(cache_path):
+            os.makedirs(cache_path)
+            self.precompute(precompute_num_workers)
+        if normalize_f0:
+            self.load_stats(cache_path)
+
+    def __getitem__(self, idx):
+        _, wav_file, *_ = _parse_sample(self.samples[idx])
+        f0 = self.compute_or_load(wav_file)
+        if self.normalize_f0:
+            assert self.mean is not None and self.std is not None, " [!] Mean and STD are not available"
+            f0 = self.normalize(f0)
+        return {"audio_file": wav_file, "f0": f0}
+
+    def __len__(self):
+        return len(self.samples)
+
+    def precompute(self, num_workers=0):
+        with tqdm.tqdm(total=len(self)) as pbar:
+            batch_size = num_workers if num_workers > 0 else 1
+            # mean/std are not available yet, so disable normalization while pre-computing
+            normalize_f0 = self.normalize_f0
+            self.normalize_f0 = False
+            dataloader = torch.utils.data.DataLoader(
+                batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
+            )
+            computed_data = []
+            for batch in dataloader:
+                f0 = batch["f0"]
+                computed_data.append([f for f in f0])
+                pbar.update(batch_size)
+            self.normalize_f0 = normalize_f0
+
+        if self.normalize_f0:
+            computed_data = [tensor for batch in computed_data for tensor in batch]  # flatten
+            pitch_mean, pitch_std = self.compute_pitch_stats(computed_data)
+            pitch_stats = {"mean": pitch_mean, "std": pitch_std}
+            np.save(os.path.join(self.cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)
+
+    def get_pad_id(self):
+        return self.pad_id

     @staticmethod
     def create_pitch_file_path(wav_file, cache_path):
@@ -519,69 +678,128 @@ class PitchExtractor:
         mean, std = np.mean(nonzeros), np.std(nonzeros)
         return mean, std

-    def normalize_pitch(self, pitch):
+    def load_stats(self, cache_path):
+        stats_path = os.path.join(cache_path, "pitch_stats.npy")
+        stats = np.load(stats_path, allow_pickle=True).item()
+        self.mean = stats["mean"].astype(np.float32)
+        self.std = stats["std"].astype(np.float32)
+
+    def normalize(self, pitch):
         zero_idxs = np.where(pitch == 0.0)[0]
         pitch = pitch - self.mean
         pitch = pitch / self.std
         pitch[zero_idxs] = 0.0
         return pitch

-    def denormalize_pitch(self, pitch):
+    def denormalize(self, pitch):
         zero_idxs = np.where(pitch == 0.0)[0]
         pitch *= self.std
         pitch += self.mean
         pitch[zero_idxs] = 0.0
         return pitch

-    @staticmethod
-    def load_or_compute_pitch(ap, wav_file, cache_path):
+    def compute_or_load(self, wav_file):
         """
         compute pitch and return a numpy array of pitch values
         """
-        pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path)
+        pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
         if not os.path.exists(pitch_file):
-            pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
+            pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
         else:
             pitch = np.load(pitch_file)
         return pitch.astype(np.float32)

-    @staticmethod
-    def _pitch_worker(args):
-        item = args[0]
-        ap = args[1]
-        cache_path = args[2]
-        pitch_file = PitchExtractor.create_pitch_file_path(item["audio_file"], cache_path)
-        if not os.path.exists(pitch_file):
-            pitch = PitchExtractor._compute_and_save_pitch(ap, item["audio_file"], pitch_file)
-            return pitch
-        return None
+    def collate_fn(self, batch):
+        audio_file = [item["audio_file"] for item in batch]
+        f0s = [item["f0"] for item in batch]
+        f0_lens = [len(item["f0"]) for item in batch]
+        f0_lens_max = max(f0_lens)
+        # F0 values are floats, so pad into a FloatTensor with the float pad ID (0.0)
+        f0s_torch = torch.FloatTensor(len(f0s), f0_lens_max).fill_(self.get_pad_id())
+        for i, f0_len in enumerate(f0_lens):
+            f0s_torch[i, :f0_len] = torch.FloatTensor(f0s[i])
+        return {"audio_file": audio_file, "f0": f0s_torch, "f0_lens": f0_lens}

-    def compute_pitch(self, ap, cache_path, num_workers=0):
-        """Compute the input sequences with multi-processing.
-        Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
-        if not os.path.exists(cache_path):
-            os.makedirs(cache_path, exist_ok=True)
+    def print_logs(self, level: int = 0) -> None:
+        indent = "\t" * level
+        print("\n")
+        print(f"{indent}> F0Dataset ")
+        print(f"{indent}| > Number of instances : {len(self.samples)}")

-        if self.verbose:
-            print(" | > Computing pitch features ...")
-        if num_workers == 0:
-            pitch_vecs = []
-            for _, item in enumerate(tqdm.tqdm(self.items)):
-                pitch_vecs += [self._pitch_worker([item, ap, cache_path])]
-        else:
-            with Pool(num_workers) as p:
-                pitch_vecs = list(
-                    tqdm.tqdm(
-                        p.imap(PitchExtractor._pitch_worker, [[item, ap, cache_path] for item in self.items]),
-                        total=len(self.items),
-                    )
-                )
-        pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs)
-        pitch_stats = {"mean": pitch_mean, "std": pitch_std}
-        np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)

-    def load_pitch_stats(self, cache_path):
-        stats_path = os.path.join(cache_path, "pitch_stats.npy")
-        stats = np.load(stats_path, allow_pickle=True).item()
-        self.mean = stats["mean"].astype(np.float32)
-        self.std = stats["std"].astype(np.float32)
+
+if __name__ == "__main__":
+    from torch.utils.data import DataLoader
+
+    from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig
+    from TTS.tts.datasets import load_tts_samples
+    from TTS.tts.utils.text.characters import IPAPhonemes
+    from TTS.tts.utils.text.phonemizers import ESpeak
+
+    dataset_config = BaseDatasetConfig(
+        name="ljspeech",
+        meta_file_train="metadata.csv",
+        path="/Users/erengolge/Projects/TTS/recipes/ljspeech/LJSpeech-1.1",
+    )
+    train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
+    samples = train_samples + eval_samples
+
+    phonemizer = ESpeak(language="en-us")
+    tokenizer = TTSTokenizer(use_phonemes=True, characters=IPAPhonemes(), phonemizer=phonemizer)
+    # ph_dataset = PhonemeDataset(samples, tokenizer, cache_path="/Users/erengolge/Projects/TTS/phonemes_tests")
+    # ph_dataset.precompute(num_workers=4)
+
+    # dataloader = DataLoader(ph_dataset, batch_size=4, shuffle=False, num_workers=4, collate_fn=ph_dataset.collate_fn)
+    # for batch in dataloader:
+    #     print(batch)
+    #     break
+
+    audio_config = BaseAudioConfig(
+        sample_rate=22050,
+        win_length=1024,
+        hop_length=256,
+        num_mels=80,
+        preemphasis=0.0,
+        ref_level_db=20,
+        log_func="np.log",
+        do_trim_silence=True,
+        trim_db=45,
+        mel_fmin=0,
+        mel_fmax=8000,
+        spec_gain=1.0,
+        signal_norm=False,
+        do_amp_to_db_linear=False,
+    )
+
+    ap = AudioProcessor.init_from_config(audio_config)
+
+    # f0_dataset = F0Dataset(samples, ap, cache_path="/Users/erengolge/Projects/TTS/f0_tests", verbose=False, precompute_num_workers=4)
+
+    # dataloader = DataLoader(f0_dataset, batch_size=4, shuffle=False, num_workers=4, collate_fn=f0_dataset.collate_fn)
+    # for batch in dataloader:
+    #     print(batch)
+    #     breakpoint()
+    #     break
+
+    dataset = TTSDataset(
+        outputs_per_step=1,
+        compute_linear_spec=False,
+        meta_data=samples,
+        ap=ap,
+        return_wav=False,
+        batch_group_size=0,
+        min_text_len=0,
+        max_text_len=500,
+        use_noise_augment=False,
+        verbose=True,
+        speaker_id_mapping=None,
+        d_vector_mapping=None,
+        compute_f0=True,
+        f0_cache_path="/Users/erengolge/Projects/TTS/f0_tests",
+        tokenizer=tokenizer,
+        phoneme_cache_path="/Users/erengolge/Projects/TTS/phonemes_tests",
+        precompute_num_workers=4,
+    )
+
+    dataloader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0, collate_fn=dataset.collate_fn)
+    for batch in dataloader:
+        print(batch)
+        break
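
For completeness, a minimal sketch of the intended call sequence with the new length-limit arguments, reusing `samples`, `ap`, and `tokenizer` from the script above; the limits and `batch_group_size` are illustrative values, not defaults from the patch:

    from torch.utils.data import DataLoader

    from TTS.tts.datasets.dataset import TTSDataset

    dataset = TTSDataset(
        outputs_per_step=1,
        compute_linear_spec=False,
        meta_data=samples,
        ap=ap,
        tokenizer=tokenizer,
        min_text_len=10,
        max_text_len=300,
        min_audio_len=22050,       # at least ~1 s at 22.05 kHz (lengths are in samples)
        max_audio_len=22050 * 10,  # at most ~10 s, the main knob against OOM
        batch_group_size=4 * 8,    # shuffle inside buckets of 8 batches
    )
    dataset.preprocess_samples()  # sort by length, drop out-of-range samples, bucket
    loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)
    batch = next(iter(loader))
    print(batch["token_id"].shape, batch["token_id_lengths"])  # keys from collate_fn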