From e802b24ad0c5e4f10796ca2e68847c246fc9331a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 14 Jul 2021 14:33:57 +0200 Subject: [PATCH] Compute mean and std pitch --- TTS/tts/datasets/TTSDataset.py | 46 +++++++++++++++++++++++++++++----- TTS/tts/models/base_tts.py | 6 +++-- 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py index 3533dede..f6bd7038 100644 --- a/TTS/tts/datasets/TTSDataset.py +++ b/TTS/tts/datasets/TTSDataset.py @@ -127,6 +127,7 @@ class TTSDataset(Dataset): self.input_seq_computed = False self.rescue_item_idx = 1 self.pitch_computed = False + if use_phonemes and not os.path.isdir(phoneme_cache_path): os.makedirs(phoneme_cache_path, exist_ok=True) if self.verbose: @@ -247,6 +248,7 @@ class TTSDataset(Dataset): pitch = None if self.compute_f0: pitch = self._load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path) + pitch = self.normalize_pitch(pitch) sample = { "raw_text": raw_text, @@ -315,6 +317,11 @@ class TTSDataset(Dataset): for idx, p in enumerate(phonemes): self.items[idx][0] = p + ################ + # Pitch Methods + ############### + # TODO: Refactor Pitch methods into a separate class + @staticmethod def create_pitch_file_path(wav_file, cache_path): file_name = os.path.splitext(os.path.basename(wav_file))[0] @@ -329,6 +336,19 @@ class TTSDataset(Dataset): np.save(pitch_file, pitch) return pitch + @staticmethod + def compute_pitch_stats(pitch_vecs): + nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs]) + mean, std = np.mean(nonzeros), np.std(nonzeros) + return mean, std + + def normalize_pitch(self, pitch): + zero_idxs = np.where(pitch == 0.0)[0] + pitch -= self.mean + pitch /= self.std + pitch[zero_idxs] = 0.0 + return pitch + @staticmethod def _load_or_compute_pitch(ap, wav_file, cache_path): """ @@ -349,9 +369,9 @@ class TTSDataset(Dataset): _, wav_file, *_ = item pitch_file = TTSDataset.create_pitch_file_path(wav_file, cache_path) if not os.path.exists(pitch_file): - TTSDataset._compute_and_save_pitch(ap, wav_file, pitch_file) - return True - return False + pitch = TTSDataset._compute_and_save_pitch(ap, wav_file, pitch_file) + return pitch + return None def compute_pitch(self, cache_path, num_workers=0): """Compute the input sequences with multi-processing. @@ -362,16 +382,30 @@ class TTSDataset(Dataset): if self.verbose: print(" | > Computing pitch features ...") if num_workers == 0: - for idx, item in enumerate(tqdm.tqdm(self.items)): - self._pitch_worker([item, self.ap, cache_path]) + pitch_vecs = [] + for _, item in enumerate(tqdm.tqdm(self.items)): + pitch_vecs += [self._pitch_worker([item, self.ap, cache_path])] else: with Pool(num_workers) as p: - _ = list( + pitch_vecs = list( tqdm.tqdm( p.imap(TTSDataset._pitch_worker, [[item, self.ap, cache_path] for item in self.items]), total=len(self.items), ) ) + pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs) + pitch_stats = {"mean": pitch_mean, "std": pitch_std} + np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True) + + def load_pitch_stats(self, cache_path): + stats_path = os.path.join(cache_path, "pitch_stats.npy") + stats = np.load(stats_path, allow_pickle=True).item() + self.mean = stats["mean"] + self.std = stats["std"] + + ################### + # End Pitch Methods + ################### def sort_and_filter_items(self, by_audio_len=False): r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 3a6957f3..9e0bf41e 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -250,8 +250,10 @@ class BaseTTS(BaseModel): dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False)) # compute pitch frames and write to files. - if config.compute_f0 and not os.path.exists(config.f0_cache_path) and rank in [None, 0]: - dataset.compute_pitch(config.get("f0_cache_path", None), config.num_loader_workers) + if config.compute_f0 and rank in [None, 0]: + if not os.path.exists(config.f0_cache_path): + dataset.compute_pitch(config.get("f0_cache_path", None), config.num_loader_workers) + dataset.load_pitch_stats(config.get("f0_cache_path", None)) # halt DDP processes for the main process to finish computing the F0 cache if num_gpus > 1: