diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 78c6c33d..ccfa70f1 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -56,10 +56,6 @@ class TTSDataset(Dataset): meta_data (list): List of dataset instances. - compute_f0 (bool): compute f0 if True. Defaults to False. - - f0_cache_path (str): Path to store f0 cache. Defaults to None. - characters (dict): `dict` of custom text characters used for converting texts to sequences. custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own @@ -109,8 +105,6 @@ class TTSDataset(Dataset): self.cleaners = text_cleaner self.compute_linear_spec = compute_linear_spec self.return_wav = return_wav - self.compute_f0 = compute_f0 - self.f0_cache_path = f0_cache_path self.min_seq_len = min_seq_len self.max_seq_len = max_seq_len self.ap = ap @@ -339,7 +333,6 @@ class TTSDataset(Dataset): else: lengths = np.array([len(ins[0]) for ins in self.items]) - # sort items based on the sequence length in ascending order idxs = np.argsort(lengths) new_items = [] ignored = [] @@ -349,10 +342,7 @@ class TTSDataset(Dataset): ignored.append(idx) else: new_items.append(self.items[idx]) - # shuffle batch groups - # create batches with similar length items - # the larger the `batch_group_size`, the higher the length variety in a batch. if self.batch_group_size > 0: for i in range(len(new_items) // self.batch_group_size): offset = i * self.batch_group_size @@ -360,14 +350,8 @@ class TTSDataset(Dataset): temp_items = new_items[offset:end_offset] random.shuffle(temp_items) new_items[offset:end_offset] = temp_items - - if len(new_items) == 0: - raise RuntimeError(" [!] No items left after filtering.") - - # update items to the new sorted items self.items = new_items - # logging if self.verbose: print(" | > Max length sequence: {}".format(np.max(lengths))) print(" | > Min length sequence: {}".format(np.min(lengths))) @@ -554,110 +538,3 @@ class TTSDataset(Dataset): ) ) ) - - -class PitchExtractor: - """Pitch Extractor for computing F0 from wav files. - - Args: - items (List[List]): Dataset samples. - verbose (bool): Whether to print the progress. - """ - - def __init__( - self, - items: List[List], - verbose=False, - ): - self.items = items - self.verbose = verbose - self.mean = None - self.std = None - - @staticmethod - def create_pitch_file_path(wav_file, cache_path): - file_name = os.path.splitext(os.path.basename(wav_file))[0] - pitch_file = os.path.join(cache_path, file_name + "_pitch.npy") - return pitch_file - - @staticmethod - def _compute_and_save_pitch(ap, wav_file, pitch_file=None): - wav = ap.load_wav(wav_file) - pitch = ap.compute_f0(wav) - if pitch_file: - np.save(pitch_file, pitch) - return pitch - - @staticmethod - def compute_pitch_stats(pitch_vecs): - nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs]) - mean, std = np.mean(nonzeros), np.std(nonzeros) - return mean, std - - def normalize_pitch(self, pitch): - zero_idxs = np.where(pitch == 0.0)[0] - pitch = pitch - self.mean - pitch = pitch / self.std - pitch[zero_idxs] = 0.0 - return pitch - - def denormalize_pitch(self, pitch): - zero_idxs = np.where(pitch == 0.0)[0] - pitch *= self.std - pitch += self.mean - pitch[zero_idxs] = 0.0 - return pitch - - @staticmethod - def load_or_compute_pitch(ap, wav_file, cache_path): - """ - compute pitch and return a numpy array of pitch values - """ - pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path) - if not os.path.exists(pitch_file): - pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file) - else: - pitch = np.load(pitch_file) - return pitch.astype(np.float32) - - @staticmethod - def _pitch_worker(args): - item = args[0] - ap = args[1] - cache_path = args[2] - _, wav_file, *_ = item - pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path) - if not os.path.exists(pitch_file): - pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file) - return pitch - return None - - def compute_pitch(self, ap, cache_path, num_workers=0): - """Compute the input sequences with multi-processing. - Call it before passing dataset to the data loader to cache the input sequences for faster data loading.""" - if not os.path.exists(cache_path): - os.makedirs(cache_path, exist_ok=True) - - if self.verbose: - print(" | > Computing pitch features ...") - if num_workers == 0: - pitch_vecs = [] - for _, item in enumerate(tqdm.tqdm(self.items)): - pitch_vecs += [self._pitch_worker([item, ap, cache_path])] - else: - with Pool(num_workers) as p: - pitch_vecs = list( - tqdm.tqdm( - p.imap(PitchExtractor._pitch_worker, [[item, ap, cache_path] for item in self.items]), - total=len(self.items), - ) - ) - pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs) - pitch_stats = {"mean": pitch_mean, "std": pitch_std} - np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True) - - def load_pitch_stats(self, cache_path): - stats_path = os.path.join(cache_path, "pitch_stats.npy") - stats = np.load(stats_path, allow_pickle=True).item() - self.mean = stats["mean"].astype(np.float32) - self.std = stats["std"].astype(np.float32)