From 648655fa0366917e078d5a52396c792578662a2a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Fri, 3 Sep 2021 13:25:57 +0000
Subject: [PATCH] Add `PitchExtractor` and return dict by `collate`

---
 TTS/tts/datasets/TTSDataset.py | 228 +++++++++++++++++----------------
 TTS/tts/models/base_tts.py     |  32 ++---
 TTS/tts/models/glow_tts.py     |   4 +-
 3 files changed, 138 insertions(+), 126 deletions(-)

diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py
index f6bd7038..74cb8de1 100644
--- a/TTS/tts/datasets/TTSDataset.py
+++ b/TTS/tts/datasets/TTSDataset.py
@@ -130,6 +130,8 @@ class TTSDataset(Dataset):
         if use_phonemes and not os.path.isdir(phoneme_cache_path):
             os.makedirs(phoneme_cache_path, exist_ok=True)
+        if compute_f0:
+            self.pitch_extractor = PitchExtractor(self.items, verbose=verbose)
         if self.verbose:
             print("\n > DataLoader initialization")
             print(" | > Use phonemes: {}".format(self.use_phonemes))
@@ -247,8 +249,8 @@ class TTSDataset(Dataset):

         pitch = None
         if self.compute_f0:
-            pitch = self._load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
-            pitch = self.normalize_pitch(pitch)
+            pitch = self.pitch_extractor._load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
+            pitch = self.pitch_extractor.normalize_pitch(pitch)

         sample = {
             "raw_text": raw_text,
@@ -317,96 +319,6 @@ class TTSDataset(Dataset):
             for idx, p in enumerate(phonemes):
                 self.items[idx][0] = p

-    ################
-    # Pitch Methods
-    ###############
-    # TODO: Refactor Pitch methods into a separate class
-
-    @staticmethod
-    def create_pitch_file_path(wav_file, cache_path):
-        file_name = os.path.splitext(os.path.basename(wav_file))[0]
-        pitch_file = os.path.join(cache_path, file_name + "_pitch.npy")
-        return pitch_file
-
-    @staticmethod
-    def _compute_and_save_pitch(ap, wav_file, pitch_file=None):
-        wav = ap.load_wav(wav_file)
-        pitch = ap.compute_f0(wav)
-        if pitch_file:
-            np.save(pitch_file, pitch)
-        return pitch
-
-    @staticmethod
-    def compute_pitch_stats(pitch_vecs):
-        nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs])
-        mean, std = np.mean(nonzeros), np.std(nonzeros)
-        return mean, std
-
-    def normalize_pitch(self, pitch):
-        zero_idxs = np.where(pitch == 0.0)[0]
-        pitch -= self.mean
-        pitch /= self.std
-        pitch[zero_idxs] = 0.0
-        return pitch
-
-    @staticmethod
-    def _load_or_compute_pitch(ap, wav_file, cache_path):
-        """
-        compute pitch and return a numpy array of pitch values
-        """
-        pitch_file = TTSDataset.create_pitch_file_path(wav_file, cache_path)
-        if not os.path.exists(pitch_file):
-            pitch = TTSDataset._compute_and_save_pitch(ap, wav_file, pitch_file)
-        else:
-            pitch = np.load(pitch_file)
-        return pitch
-
-    @staticmethod
-    def _pitch_worker(args):
-        item = args[0]
-        ap = args[1]
-        cache_path = args[2]
-        _, wav_file, *_ = item
-        pitch_file = TTSDataset.create_pitch_file_path(wav_file, cache_path)
-        if not os.path.exists(pitch_file):
-            pitch = TTSDataset._compute_and_save_pitch(ap, wav_file, pitch_file)
-            return pitch
-        return None
-
-    def compute_pitch(self, cache_path, num_workers=0):
-        """Compute the input sequences with multi-processing.
-        Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
-        if not os.path.exists(cache_path):
-            os.makedirs(cache_path, exist_ok=True)
-
-        if self.verbose:
-            print(" | > Computing pitch features ...")
-        if num_workers == 0:
-            pitch_vecs = []
-            for _, item in enumerate(tqdm.tqdm(self.items)):
-                pitch_vecs += [self._pitch_worker([item, self.ap, cache_path])]
-        else:
-            with Pool(num_workers) as p:
-                pitch_vecs = list(
-                    tqdm.tqdm(
-                        p.imap(TTSDataset._pitch_worker, [[item, self.ap, cache_path] for item in self.items]),
-                        total=len(self.items),
-                    )
-                )
-        pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs)
-        pitch_stats = {"mean": pitch_mean, "std": pitch_std}
-        np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)
-
-    def load_pitch_stats(self, cache_path):
-        stats_path = os.path.join(cache_path, "pitch_stats.npy")
-        stats = np.load(stats_path, allow_pickle=True).item()
-        self.mean = stats["mean"]
-        self.std = stats["std"]
-
-    ###################
-    # End Pitch Methods
-    ###################
-
     def sort_and_filter_items(self, by_audio_len=False):
         r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the
         length range.
@@ -588,22 +500,22 @@ class TTSDataset(Dataset):
             else:
                 attns = None
             # TODO: return dictionary
-            return (
-                text,
-                text_lenghts,
-                speaker_names,
-                linear,
-                mel,
-                mel_lengths,
-                stop_targets,
-                item_idxs,
-                d_vectors,
-                speaker_ids,
-                attns,
-                wav_padded,
-                raw_text,
-                pitch,
-            )
+            return {
+                "text": text,
+                "text_lengths": text_lenghts,
+                "speaker_names": speaker_names,
+                "linear": linear,
+                "mel": mel,
+                "mel_lengths": mel_lengths,
+                "stop_targets": stop_targets,
+                "item_idxs": item_idxs,
+                "d_vectors": d_vectors,
+                "speaker_ids": speaker_ids,
+                "attns": attns,
+                "waveform": wav_padded,
+                "raw_text": raw_text,
+                "pitch": pitch,
+            }

         raise TypeError(
             (
@@ -613,3 +525,103 @@ class TTSDataset(Dataset):
                 )
             )
         )
+
+
+class PitchExtractor:
+    """Pitch Extractor for computing F0 from wav files.
+
+    Args:
+        items (List[List]): Dataset samples.
+        verbose (bool): Whether to print the progress.
+ """ + + def __init__( + self, + items: List[List], + verbose=False, + ): + self.items = items + self.verbose = verbose + self.mean = None + self.std = None + + @staticmethod + def create_pitch_file_path(wav_file, cache_path): + file_name = os.path.splitext(os.path.basename(wav_file))[0] + pitch_file = os.path.join(cache_path, file_name + "_pitch.npy") + return pitch_file + + @staticmethod + def _compute_and_save_pitch(ap, wav_file, pitch_file=None): + wav = ap.load_wav(wav_file) + pitch = ap.compute_f0(wav) + if pitch_file: + np.save(pitch_file, pitch) + return pitch + + @staticmethod + def compute_pitch_stats(pitch_vecs): + nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs]) + mean, std = np.mean(nonzeros), np.std(nonzeros) + return mean, std + + def normalize_pitch(self, pitch): + zero_idxs = np.where(pitch == 0.0)[0] + pitch -= self.mean + pitch /= self.std + pitch[zero_idxs] = 0.0 + return pitch + + @staticmethod + def _load_or_compute_pitch(ap, wav_file, cache_path): + """ + compute pitch and return a numpy array of pitch values + """ + pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path) + if not os.path.exists(pitch_file): + pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file) + else: + pitch = np.load(pitch_file) + return pitch + + @staticmethod + def _pitch_worker(args): + item = args[0] + ap = args[1] + cache_path = args[2] + _, wav_file, *_ = item + pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path) + if not os.path.exists(pitch_file): + pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file) + return pitch + return None + + def compute_pitch(self, cache_path, num_workers=0): + """Compute the input sequences with multi-processing. + Call it before passing dataset to the data loader to cache the input sequences for faster data loading.""" + if not os.path.exists(cache_path): + os.makedirs(cache_path, exist_ok=True) + + if self.verbose: + print(" | > Computing pitch features ...") + if num_workers == 0: + pitch_vecs = [] + for _, item in enumerate(tqdm.tqdm(self.items)): + pitch_vecs += [self._pitch_worker([item, self.ap, cache_path])] + else: + with Pool(num_workers) as p: + pitch_vecs = list( + tqdm.tqdm( + p.imap(PitchExtractor._pitch_worker, [[item, self.ap, cache_path] for item in self.items]), + total=len(self.items), + ) + ) + pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs) + pitch_stats = {"mean": pitch_mean, "std": pitch_std} + np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True) + + def load_pitch_stats(self, cache_path): + stats_path = os.path.join(cache_path, "pitch_stats.npy") + stats = np.load(stats_path, allow_pickle=True).item() + self.mean = stats["mean"] + self.std = stats["std"] diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 9e0bf41e..653143cd 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -104,19 +104,19 @@ class BaseTTS(BaseModel): Dict: [description] """ # setup input batch - text_input = batch[0] - text_lengths = batch[1] - speaker_names = batch[2] - linear_input = batch[3] - mel_input = batch[4] - mel_lengths = batch[5] - stop_targets = batch[6] - item_idx = batch[7] - d_vectors = batch[8] - speaker_ids = batch[9] - attn_mask = batch[10] - waveform = batch[11] - pitch = batch[13] + text_input = batch["text"] + text_lengths = batch["text_lengths"] + speaker_names = batch["speaker_names"] + linear_input = batch["linear"] + mel_input = batch["mel"] + mel_lengths 
= batch["mel_lengths"] + stop_targets = batch["stop_targets"] + item_idx = batch["item_idxs"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + attn_mask = batch["attns"] + waveform = batch["waveform"] + pitch = batch["pitch"] max_text_length = torch.max(text_lengths.float()) max_spec_length = torch.max(mel_lengths.float()) @@ -201,7 +201,7 @@ class BaseTTS(BaseModel): outputs_per_step=config.r if "r" in config else 1, text_cleaner=config.text_cleaner, compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec, - comnpute_f0=config.get("compute_f0", False), + compute_f0=config.get("compute_f0", False), f0_cache_path=config.get("f0_cache_path", None), meta_data=data_items, ap=ap, @@ -252,8 +252,8 @@ class BaseTTS(BaseModel): # compute pitch frames and write to files. if config.compute_f0 and rank in [None, 0]: if not os.path.exists(config.f0_cache_path): - dataset.compute_pitch(config.get("f0_cache_path", None), config.num_loader_workers) - dataset.load_pitch_stats(config.get("f0_cache_path", None)) + dataset.pitch_extractor.compute_pitch(config.get("f0_cache_path", None), config.num_loader_workers) + dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None)) # halt DDP processes for the main process to finish computing the F0 cache if num_gpus > 1: diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index e6541871..27012207 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -134,9 +134,9 @@ class GlowTTS(BaseTTS): """ Shapes: - x: :math:`[B, T]` - - x_lenghts::math:` B` + - x_lenghts::math:`B` - y: :math:`[B, T, C]` - - y_lengths::math:` B` + - y_lengths::math:`B` - g: :math:`[B, C] or B` """ y = y.transpose(1, 2)