From d085642ac1b13d296731b1a84f6e3b06f4888f32 Mon Sep 17 00:00:00 2001
From: Eren Gölge
Date: Mon, 12 Jul 2021 18:23:19 +0200
Subject: [PATCH] Cache pitch features

Cache the pitch features at the beginning of `BaseTTS` training.
---
 TTS/tts/datasets/TTSDataset.py | 95 +++++++++++++++++++++++++++++-----
 TTS/tts/models/base_tts.py     | 11 ++++
 TTS/utils/audio.py             | 57 ++++++++++++++++----
 3 files changed, 140 insertions(+), 23 deletions(-)
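
Note for reviewers: a minimal sketch of how the new options fit together. The
keyword arguments are the ones visible in this patch (plus `outputs_per_step`,
`text_cleaner` and `compute_linear_spec` from `get_data_loader` below); the
audio parameters, paths, and metadata values are illustrative assumptions, not
taken from this patch.

    from TTS.tts.datasets.TTSDataset import TTSDataset
    from TTS.utils.audio import AudioProcessor

    ap = AudioProcessor(sample_rate=22050, win_length=1024, hop_length=256)

    # items follow the loader convention [text, wav_file, speaker_name]
    meta_data = [["Hello world.", "wavs/sample_0001.wav", "speaker0"]]

    dataset = TTSDataset(
        outputs_per_step=1,
        text_cleaner="english_cleaners",
        compute_linear_spec=False,
        ap=ap,
        meta_data=meta_data,
        compute_f0=True,            # new: attach a pitch contour to each sample
        f0_cache_path="cache/f0",   # new: directory for per-file *_pitch.npy arrays
    )

    # fill the cache once, up front; afterwards load_data() only reads .npy files
    dataset.compute_pitch("cache/f0", num_workers=4)

`BaseTTS.get_data_loader` below does the equivalent, driven by
`config.compute_f0` and `config.f0_cache_path`.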
diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py
index 7ad52797..9b841034 100644
--- a/TTS/tts/datasets/TTSDataset.py
+++ b/TTS/tts/datasets/TTSDataset.py
@@ -9,7 +9,7 @@ import torch
 import tqdm
 from torch.utils.data import Dataset
 
-from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
+from TTS.tts.utils.data import _pad_data, prepare_data, prepare_stop_target, prepare_tensor
 from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
 from TTS.utils.audio import AudioProcessor
 
@@ -23,6 +23,7 @@ class TTSDataset(Dataset):
         ap: AudioProcessor,
         meta_data: List[List],
         compute_f0: bool = False,
+        f0_cache_path: str = None,
         characters: Dict = None,
         custom_symbols: List = None,
         add_blank: bool = False,
@@ -41,8 +42,7 @@ class TTSDataset(Dataset):
     ):
         """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
 
-        If you need something different, you can either override or create a new class as the dataset is
-        initialized by the model.
+        If you need something different, you can inherit and override.
 
         Args:
             outputs_per_step (int): Number of time frames predicted per step.
@@ -57,6 +57,8 @@ class TTSDataset(Dataset):
 
             compute_f0 (bool): compute f0 if True. Defaults to False.
 
+            f0_cache_path (str): Path to store the f0 cache. Defaults to None.
+
             characters (dict): `dict` of custom text characters used for converting texts to sequences.
 
             custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own
@@ -81,8 +83,8 @@ class TTSDataset(Dataset):
 
             use_phonemes (bool): If true, input text converted to phonemes. Defaults to false.
 
-            phoneme_cache_path (str): Path to cache phoneme features. It writes computed phonemes to files to use in
-                the coming iterations. Defaults to None.
+            phoneme_cache_path (str): Path to cache computed phonemes. It writes the phonemes of each sample to a
+                separate file. Defaults to None.
 
             phoneme_language (str): One of the languages supported by the phonemizer interface. Defaults to `en-us`.
@@ -107,6 +109,7 @@ class TTSDataset(Dataset):
         self.compute_linear_spec = compute_linear_spec
         self.return_wav = return_wav
         self.compute_f0 = compute_f0
+        self.f0_cache_path = f0_cache_path
         self.min_seq_len = min_seq_len
         self.max_seq_len = max_seq_len
         self.ap = ap
@@ -123,6 +126,7 @@ class TTSDataset(Dataset):
         self.verbose = verbose
         self.input_seq_computed = False
         self.rescue_item_idx = 1
+        self.pitch_computed = False
         if use_phonemes and not os.path.isdir(phoneme_cache_path):
             os.makedirs(phoneme_cache_path, exist_ok=True)
         if self.verbose:
@@ -240,10 +244,15 @@ class TTSDataset(Dataset):
             # TODO: find a better fix
             return self.load_data(self.rescue_item_idx)
 
+        pitch = None
+        if self.compute_f0:
+            pitch = self._load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
+
         sample = {
             "raw_text": raw_text,
             "text": text,
             "wav": wav,
+            "pitch": pitch,
             "attn": attn,
             "item_idx": self.items[idx][1],
             "speaker_name": speaker_name,
@@ -260,8 +268,8 @@ class TTSDataset(Dataset):
         return phonemes
 
     def compute_input_seq(self, num_workers=0):
-        """compute input sequences separately. Call it before
-        passing dataset to data loader."""
+        """Compute the input sequences with multi-processing.
+        Call it before passing the dataset to the data loader to cache the input sequences for faster data loading."""
         if not self.use_phonemes:
             if self.verbose:
                 print(" | > Computing input sequences ...")
@@ -306,6 +314,64 @@ class TTSDataset(Dataset):
                 for idx, p in enumerate(phonemes):
                     self.items[idx][0] = p
 
+    @staticmethod
+    def create_pitch_file_path(wav_file, cache_path):
+        file_name = os.path.splitext(os.path.basename(wav_file))[0]
+        pitch_file = os.path.join(cache_path, file_name + "_pitch.npy")
+        return pitch_file
+
+    @staticmethod
+    def _compute_and_save_pitch(ap, wav_file, pitch_file=None):
+        wav = ap.load_wav(wav_file)
+        pitch = ap.compute_f0(wav)
+        if pitch_file:
+            np.save(pitch_file, pitch)
+        return pitch
+
+    @staticmethod
+    def _load_or_compute_pitch(ap, wav_file, cache_path):
+        """Load the pitch values from the cache file if it exists,
+        otherwise compute and cache them. Returns a numpy array of pitch values."""
+        pitch_file = TTSDataset.create_pitch_file_path(wav_file, cache_path)
+        if not os.path.exists(pitch_file):
+            pitch = TTSDataset._compute_and_save_pitch(ap, wav_file, pitch_file)
+        else:
+            pitch = np.load(pitch_file)
+        return pitch
+
+    @staticmethod
+    def _pitch_worker(args):
+        item = args[0]
+        ap = args[1]
+        cache_path = args[2]
+        _, wav_file, *_ = item
+        pitch_file = TTSDataset.create_pitch_file_path(wav_file, cache_path)
+        if not os.path.exists(pitch_file):
+            TTSDataset._compute_and_save_pitch(ap, wav_file, pitch_file)
+            return True
+        return False
+
+    def compute_pitch(self, cache_path, num_workers=0):
+        """Compute the pitch features with multi-processing.
+        Call it before passing the dataset to the data loader to cache the pitch features for faster data loading."""
+        if not os.path.exists(cache_path):
+            os.makedirs(cache_path, exist_ok=True)
+
+        if self.verbose:
+            print(" | > Computing pitch features ...")
+        if num_workers == 0:
+            for item in tqdm.tqdm(self.items):
+                self._pitch_worker([item, self.ap, cache_path])
+        else:
+            with Pool(num_workers) as p:
+                _ = list(
+                    tqdm.tqdm(
+                        p.imap(TTSDataset._pitch_worker, [[item, self.ap, cache_path] for item in self.items]),
+                        total=len(self.items),
+                    )
+                )
+
     def sort_and_filter_items(self, by_audio_len=False):
         r"""Sort `items` based on text length or audio length in ascending order. Filter out samples outside the length range.
@@ -367,7 +433,7 @@ class TTSDataset(Dataset):
         r"""
             Perform preprocessing and create a final data batch:
            1. Sort batch instances by text-length
-           2. Convert Audio signal to Spectrograms.
+           2. Convert Audio signal to features.
            3. PAD sequences wrt r.
            4. Load to Torch.
        """
@@ -466,11 +532,12 @@ class TTSDataset(Dataset):
             # TODO: compare perf in collate_fn vs in load_data
             pitch = None
             if self.compute_f0:
-                pitch = [self.ap.compute_f0(w).astype("float32") for w in wav]
-                pitch = prepare_tensor(pitch, self.outputs_per_step)
-                pitch = pitch.transpose(0, 2, 1)
-                assert mel.shape[1] == pitch.shape[1]
-                pitch = torch.FloatTensor(pitch).contiguous()
+                pitch = [b["pitch"] for b in batch]
+                pitch = prepare_data(pitch)
+                assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
+                pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()  # B x 1 x T
+            else:
+                pitch = None
 
             # collate attention alignments
             if batch[0]["attn"] is not None:
@@ -478,6 +545,7 @@ class TTSDataset(Dataset):
                 for idx, attn in enumerate(attns):
                     pad2 = mel.shape[1] - attn.shape[1]
                     pad1 = text.shape[1] - attn.shape[0]
+                    assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}"
                     attn = np.pad(attn, [[0, pad1], [0, pad2]])
                     attns[idx] = attn
                 attns = prepare_tensor(attns, self.outputs_per_step)
@@ -499,6 +567,7 @@ class TTSDataset(Dataset):
                 attns,
                 wav_padded,
                 raw_text,
+                pitch,
             )
 
         raise TypeError(
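
To make the new collate path concrete: each sample now carries a
variable-length pitch contour, `prepare_data` is assumed to zero-pad the
arrays to the longest one in the batch, and the `[:, None, :]` reshape yields
the B x 1 x T layout asserted against the mel frames. A self-contained sketch
with a stand-in for `prepare_data` (the stub is an assumption about its
behavior, not the library code):

    import numpy as np
    import torch

    def prepare_data_stub(inputs):
        # stand-in for TTS.tts.utils.data.prepare_data: zero-pad to batch max
        max_len = max(len(x) for x in inputs)
        return np.stack([np.pad(x, (0, max_len - len(x))) for x in inputs])

    batch_pitch = [np.zeros(37, dtype=np.float32), np.zeros(42, dtype=np.float32)]
    pitch = prepare_data_stub(batch_pitch)                      # (B, T_max)
    pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()   # (B, 1, T_max)
    assert pitch.shape == (2, 1, 42)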
diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index d39473c7..3a6957f3 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -116,6 +116,7 @@ class BaseTTS(BaseModel):
         speaker_ids = batch[9]
         attn_mask = batch[10]
         waveform = batch[11]
+        pitch = batch[13]
         max_text_length = torch.max(text_lengths.float())
         max_spec_length = torch.max(mel_lengths.float())
 
@@ -162,6 +163,7 @@ class BaseTTS(BaseModel):
             "max_spec_length": float(max_spec_length),
             "item_idx": item_idx,
             "waveform": waveform,
+            "pitch": pitch,
         }
 
     def get_data_loader(
@@ -200,6 +202,7 @@ class BaseTTS(BaseModel):
                     text_cleaner=config.text_cleaner,
                     compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
                     compute_f0=config.get("compute_f0", False),
+                    f0_cache_path=config.get("f0_cache_path", None),
                     meta_data=data_items,
                     ap=ap,
                     characters=config.characters,
@@ -246,6 +249,14 @@ class BaseTTS(BaseModel):
                 # sort input sequences from short to long
                 dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False))
 
+                # compute pitch frames and write to files; compute_pitch() skips already cached files
+                if config.compute_f0 and rank in [None, 0]:
+                    dataset.compute_pitch(config.get("f0_cache_path", None), config.num_loader_workers)
+
+                # halt DDP processes for the main process to finish computing the F0 cache
+                if num_gpus > 1:
+                    dist.barrier()
+
                 # sampler for DDP
                 sampler = DistributedSampler(dataset) if num_gpus > 1 else None
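
The cache fill above follows the usual DDP warm-up pattern: only the main
process writes the files while the other workers wait at the barrier and then
read them. A stripped-down sketch of the pattern (the function and its
arguments are placeholders, not from this patch):

    import torch.distributed as dist

    def warm_cache(rank, num_gpus, fill_fn):
        if rank in (None, 0):
            fill_fn()       # e.g. dataset.compute_pitch(cache_path, num_workers)
        if num_gpus > 1:
            dist.barrier()  # workers resume once the cache is complete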
diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py
index e027b060..3d45b325 100644
--- a/TTS/utils/audio.py
+++ b/TTS/utils/audio.py
@@ -9,8 +9,7 @@ import torch
 from torch import nn
 
 from TTS.tts.utils.data import StandardScaler
-
-# import pyworld as pw
+from TTS.utils.yin import compute_yin
 
 
 class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
@@ -648,15 +647,18 @@ class AudioProcessor(object):
         #     frame_period=1000 * self.hop_length / self.sample_rate,
         # )
         # f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate)
-        # f0 = compute_yin(, self.sample_rate, self.hop_length, self.fft_size)
-        f0, _, _ = librosa.pyin(
-            x.astype(np.double),
-            fmin=65 if self.mel_fmin == 0 else self.mel_fmin,
-            fmax=self.mel_fmax,
-            frame_length=self.win_length,
-            sr=self.sample_rate,
-            fill_na=0.0,
+        f0, _, _, _ = compute_yin(
+            x,
+            self.sample_rate,
+            self.win_length,
+            self.hop_length,
+            65 if self.mel_fmin == 0 else self.mel_fmin,
+            self.mel_fmax,
         )
+        # pad f0 so it aligns with the center-padded spectrogram frames
+        pad = int((self.win_length / self.hop_length) / 2)
+        f0 = [0.0] * pad + f0 + [0.0] * pad
+        f0 = np.array(f0, dtype=np.float32)
         return f0
 
     ### Audio Processing ###
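
On the padding added to `compute_f0`: a tracker that emits one value per fully
contained analysis window produces fewer frames than a center-padded
spectrogram, and `int((win_length / hop_length) / 2)` zeros on each side close
exactly that gap when `win_length` is a multiple of `hop_length`. The
frame-count formulas below are assumptions about the two analyzers, not code
from this patch:

    win_length, hop_length = 1024, 256
    n_samples = 50 * hop_length                        # divisible by hop for clarity

    n_f0 = (n_samples - win_length) // hop_length + 1  # windowed tracker: 47 frames
    n_mel = n_samples // hop_length + 1                # center-padded STFT: 51 frames

    pad = int((win_length / hop_length) / 2)           # 2, same formula as compute_f0
    assert n_f0 + 2 * pad == n_mel                     # 47 + 2*2 == 51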