From aed919cf1c895c51bdd10fc02a1388e909286e0a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Fri, 18 Jun 2021 13:42:36 +0200
Subject: [PATCH] Update `vocoder` datasets and `setup_dataset`

---
 TTS/vocoder/datasets/__init__.py         | 57 ++++++++++++++++++++++++
 TTS/vocoder/datasets/preprocess.py       | 17 ++++++-
 TTS/vocoder/datasets/wavegrad_dataset.py |  2 +-
 TTS/vocoder/datasets/wavernn_dataset.py  | 27 ++++++-----
 4 files changed, 89 insertions(+), 14 deletions(-)

diff --git a/TTS/vocoder/datasets/__init__.py b/TTS/vocoder/datasets/__init__.py
index e69de29b..86b059c3 100644
--- a/TTS/vocoder/datasets/__init__.py
+++ b/TTS/vocoder/datasets/__init__.py
@@ -0,0 +1,57 @@
+from typing import List
+
+from coqpit import Coqpit
+from torch.utils.data import Dataset
+
+from TTS.utils.audio import AudioProcessor
+from TTS.vocoder.datasets.gan_dataset import GANDataset
+from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
+from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
+
+
+def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool) -> Dataset:
+    if config.model.lower() in "gan":
+        dataset = GANDataset(
+            ap=ap,
+            items=data_items,
+            seq_len=config.seq_len,
+            hop_len=ap.hop_length,
+            pad_short=config.pad_short,
+            conv_pad=config.conv_pad,
+            return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
+            is_training=not is_eval,
+            return_segments=not is_eval,
+            use_noise_augment=config.use_noise_augment,
+            use_cache=config.use_cache,
+            verbose=verbose,
+        )
+        dataset.shuffle_mapping()
+    elif config.model.lower() == "wavegrad":
+        dataset = WaveGradDataset(
+            ap=ap,
+            items=data_items,
+            seq_len=config.seq_len,
+            hop_len=ap.hop_length,
+            pad_short=config.pad_short,
+            conv_pad=config.conv_pad,
+            is_training=not is_eval,
+            return_segments=True,
+            use_noise_augment=False,
+            use_cache=config.use_cache,
+            verbose=verbose,
+        )
+    elif config.model.lower() == "wavernn":
+        dataset = WaveRNNDataset(
+            ap=ap,
+            items=data_items,
+            seq_len=config.seq_len,
+            hop_len=ap.hop_length,
+            pad=config.model_params.pad,
+            mode=config.model_params.mode,
+            mulaw=config.model_params.mulaw,
+            is_training=not is_eval,
+            verbose=verbose,
+        )
+    else:
+        raise ValueError(f" [!] Dataset for model {config.model.lower()} cannot be found.")
+    return dataset
diff --git a/TTS/vocoder/datasets/preprocess.py b/TTS/vocoder/datasets/preprocess.py
index d99ee147..c4569b3d 100644
--- a/TTS/vocoder/datasets/preprocess.py
+++ b/TTS/vocoder/datasets/preprocess.py
@@ -3,10 +3,21 @@ import os
 from pathlib import Path
 
 import numpy as np
+from coqpit import Coqpit
 from tqdm import tqdm
 
+from TTS.utils.audio import AudioProcessor
 
-def preprocess_wav_files(out_path, config, ap):
+
+def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
+    """Process wav and compute mel and quantized wave signal.
+    It is mainly used by WaveRNN dataloader.
+
+    Args:
+        out_path (str): Parent folder path to save the files.
+        config (Coqpit): Model config.
+        ap (AudioProcessor): Audio processor.
+    """
     os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
     os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
     wav_files = find_wav_files(config.data_path)
@@ -18,7 +29,9 @@ def preprocess_wav_files(out_path, config, ap):
         mel = ap.melspectrogram(y)
         np.save(mel_path, mel)
         if isinstance(config.mode, int):
-            quant = ap.mulaw_encode(y, qc=config.mode) if config.mulaw else ap.quantize(y, bits=config.mode)
+            quant = (
+                ap.mulaw_encode(y, qc=config.mode) if config.model_params.mulaw else ap.quantize(y, bits=config.mode)
+            )
             np.save(quant_path, quant)
 
 
diff --git a/TTS/vocoder/datasets/wavegrad_dataset.py b/TTS/vocoder/datasets/wavegrad_dataset.py
index c0d24e84..d99fc417 100644
--- a/TTS/vocoder/datasets/wavegrad_dataset.py
+++ b/TTS/vocoder/datasets/wavegrad_dataset.py
@@ -136,4 +136,4 @@ class WaveGradDataset(Dataset):
             mels[idx, :, : mel.shape[1]] = mel
             audios[idx, : audio.shape[0]] = audio
 
-        return mels, audios
+        return audios, mels
diff --git a/TTS/vocoder/datasets/wavernn_dataset.py b/TTS/vocoder/datasets/wavernn_dataset.py
index 1596ea8f..d648b68c 100644
--- a/TTS/vocoder/datasets/wavernn_dataset.py
+++ b/TTS/vocoder/datasets/wavernn_dataset.py
@@ -10,16 +10,7 @@ class WaveRNNDataset(Dataset):
     """
 
     def __init__(
-        self,
-        ap,
-        items,
-        seq_len,
-        hop_len,
-        pad,
-        mode,
-        mulaw,
-        is_training=True,
-        verbose=False,
+        self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, verbose=False, return_segments=True
     ):
 
         super().__init__()
@@ -34,6 +25,7 @@ class WaveRNNDataset(Dataset):
         self.mulaw = mulaw
         self.is_training = is_training
         self.verbose = verbose
+        self.return_segments = return_segments
 
         assert self.seq_len % self.hop_len == 0
 
@@ -44,6 +36,16 @@ class WaveRNNDataset(Dataset):
         item = self.load_item(index)
         return item
 
+    def load_test_samples(self, num_samples):
+        samples = []
+        return_segments = self.return_segments
+        self.return_segments = False
+        for idx in range(num_samples):
+            mel, audio, _ = self.load_item(idx)
+            samples.append([mel, audio])
+        self.return_segments = return_segments
+        return samples
+
     def load_item(self, index):
         """
        load (audio, feat) couple if feature_path is set
@@ -53,7 +55,10 @@ class WaveRNNDataset(Dataset):
         wavpath = self.item_list[index]
         audio = self.ap.load_wav(wavpath)
 
-        min_audio_len = 2 * self.seq_len + (2 * self.pad * self.hop_len)
+        if self.return_segments:
+            min_audio_len = 2 * self.seq_len + (2 * self.pad * self.hop_len)
+        else:
+            min_audio_len = audio.shape[0] + (2 * self.pad * self.hop_len)
         if audio.shape[0] < min_audio_len:
             print(" [!] Instance is too short! : {}".format(wavpath))
             audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len])
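
A minimal sketch of how the new `setup_dataset` factory can be driven end to end. The `WavegradConfig` import path and the concrete values are assumptions for illustration; any Coqpit config exposing the fields the factory reads (`model`, `seq_len`, `pad_short`, `conv_pad`, `use_noise_augment`, `use_cache`, plus `audio` and `data_path`) would do:

from TTS.utils.audio import AudioProcessor
from TTS.vocoder.datasets import setup_dataset
from TTS.vocoder.datasets.preprocess import find_wav_files

# Assumed import path; adjust to wherever the vocoder configs live in your tree.
from TTS.vocoder.configs.wavegrad_config import WavegradConfig

config = WavegradConfig()
ap = AudioProcessor(**config.audio.to_dict())
data_items = find_wav_files(config.data_path)

# Training split: the WaveGrad branch always samples fixed `seq_len` segments.
train_set = setup_dataset(config, ap, is_eval=False, data_items=data_items, verbose=True)
# Eval split: `is_training` flips off; GAN datasets also stop segment sampling.
eval_set = setup_dataset(config, ap, is_eval=True, data_items=data_items, verbose=False)

Note that the GAN branch matches with `config.model.lower() in "gan"` (substring membership), so only model strings that are substrings of "gan" reach `GANDataset`, while the `wavegrad` and `wavernn` branches use plain equality.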
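In `preprocess.py`, the mu-law flag now comes from `config.model_params.mulaw`, matching where the WaveRNN model keeps it. For context, mu-law companding compresses amplitudes before quantization so quiet signal keeps more resolution; the sketch below is the textbook transform only, not the repo's `AudioProcessor.mulaw_encode`:

import numpy as np

def mulaw_encode_sketch(wav: np.ndarray, qc: int) -> np.ndarray:
    """Textbook mu-law companding plus uniform quantization (illustrative)."""
    mu = 2 ** qc - 1
    # Compand: sign(x) * ln(1 + mu * |x|) / ln(1 + mu), still in [-1, 1].
    companded = np.sign(wav) * np.log1p(mu * np.abs(wav)) / np.log1p(mu)
    # Map [-1, 1] onto integer codes [0, mu].
    return np.floor((companded + 1) / 2 * mu + 0.5).astype(np.int64)

codes = mulaw_encode_sketch(np.array([-0.5, 0.0, 0.25]), qc=10)  # codes in [0, 1023]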
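The one-line `wavegrad_dataset.py` change swaps the full-clip collate output to `(audios, mels)`, so consumers unpack waveforms first. A consumer sketch, assuming the hunk sits in a `collate_full_clips` method and the dataset was built with `return_segments=False`; loader settings are illustrative:

from torch.utils.data import DataLoader
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset

def make_full_clip_loader(dataset: WaveGradDataset) -> DataLoader:
    # Batches now come back as (audios, mels): audios [B, T_audio],
    # mels [B, n_mels, T_mel], both zero-padded to the batch maximum.
    return DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_full_clips)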
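Finally, a sketch of the new WaveRNN flow: `load_test_samples` flips `return_segments` off so `load_item` keeps whole clips, padding only by `2 * pad * hop_len`, instead of enforcing the `2 * seq_len` floor used for training crops. Constructor values are illustrative, and `ap` / `data_items` are as in the first sketch:

from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset

dataset = WaveRNNDataset(
    ap=ap,              # AudioProcessor, e.g. from the first sketch above
    items=data_items,
    seq_len=1280,       # must satisfy seq_len % hop_len == 0
    hop_len=256,
    pad=2,
    mode=10,            # int => quantized integer targets (bits)
    mulaw=True,
    is_training=True,   # return_segments defaults to True for training
)

mel, audio, _ = dataset[0]                 # full clip, padded up to the training length floor
test_clips = dataset.load_test_samples(4)  # [[mel, audio], ...] without that floor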