Update `vocoder` datasets and `setup_dataset`

Eren Gölge 2021-06-18 13:42:36 +02:00
parent d18198dff8
commit d7225eedb0
4 changed files with 89 additions and 14 deletions

TTS/vocoder/datasets/__init__.py

@@ -0,0 +1,57 @@
from typing import List

from coqpit import Coqpit
from torch.utils.data import Dataset

from TTS.utils.audio import AudioProcessor
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset

def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool) -> Dataset:
    if "gan" in config.model.lower():  # matches "gan" and GAN variants, e.g. "melgan", "hifigan"
        dataset = GANDataset(
            ap=ap,
            items=data_items,
            seq_len=config.seq_len,
            hop_len=ap.hop_length,
            pad_short=config.pad_short,
            conv_pad=config.conv_pad,
            return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
            is_training=not is_eval,
            return_segments=not is_eval,
            use_noise_augment=config.use_noise_augment,
            use_cache=config.use_cache,
            verbose=verbose,
        )
        dataset.shuffle_mapping()
    elif config.model.lower() == "wavegrad":
        dataset = WaveGradDataset(
            ap=ap,
            items=data_items,
            seq_len=config.seq_len,
            hop_len=ap.hop_length,
            pad_short=config.pad_short,
            conv_pad=config.conv_pad,
            is_training=not is_eval,
            return_segments=True,
            use_noise_augment=False,
            use_cache=config.use_cache,
            verbose=verbose,
        )
    elif config.model.lower() == "wavernn":
        dataset = WaveRNNDataset(
            ap=ap,
            items=data_items,
            seq_len=config.seq_len,
            hop_len=ap.hop_length,
            pad=config.model_params.pad,
            mode=config.model_params.mode,
            mulaw=config.model_params.mulaw,
            is_training=not is_eval,
            verbose=verbose,
        )
    else:
        raise ValueError(f" [!] Dataset for model {config.model.lower()} cannot be found.")
    return dataset
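
For orientation, a minimal usage sketch of this new factory, not part of the commit: the `make_loader` helper and its arguments are illustrative assumptions, and `setup_dataset` is imported assuming the new file above is `TTS/vocoder/datasets/__init__.py`.

from torch.utils.data import DataLoader

from TTS.vocoder.datasets import setup_dataset

def make_loader(config, ap, wav_paths, is_eval=False):
    # Hypothetical helper: `config` is any vocoder Coqpit config, `ap` an
    # AudioProcessor, and `wav_paths` the data_items list of wav file paths.
    dataset = setup_dataset(config, ap, is_eval, wav_paths, verbose=not is_eval)
    # batch_size comes from the base training config; shuffle only for training.
    return DataLoader(dataset, batch_size=config.batch_size, shuffle=not is_eval)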

TTS/vocoder/datasets/preprocess.py

@@ -3,10 +3,21 @@ import os
 from pathlib import Path

 import numpy as np
+from coqpit import Coqpit
 from tqdm import tqdm

+from TTS.utils.audio import AudioProcessor

-def preprocess_wav_files(out_path, config, ap):
+def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
+    """Process wav files and compute mel and quantized wave signals.
+    It is mainly used by the WaveRNN dataloader.
+
+    Args:
+        out_path (str): Parent folder path to save the files.
+        config (Coqpit): Model config.
+        ap (AudioProcessor): Audio processor.
+    """
     os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
     os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
     wav_files = find_wav_files(config.data_path)
@@ -18,7 +29,9 @@ def preprocess_wav_files(out_path, config, ap):
         mel = ap.melspectrogram(y)
         np.save(mel_path, mel)
         if isinstance(config.mode, int):
-            quant = ap.mulaw_encode(y, qc=config.mode) if config.mulaw else ap.quantize(y, bits=config.mode)
+            quant = (
+                ap.mulaw_encode(y, qc=config.mode) if config.model_params.mulaw else ap.quantize(y, bits=config.mode)
+            )
             np.save(quant_path, quant)
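
A hedged sketch of invoking the updated helper ahead of WaveRNN training; the `WavernnConfig` class and all paths below are assumptions for illustration, not taken from this commit:

from TTS.utils.audio import AudioProcessor
from TTS.vocoder.configs import WavernnConfig  # assumed Coqpit config exposing model_params.mulaw
from TTS.vocoder.datasets.preprocess import preprocess_wav_files

config = WavernnConfig()
config.data_path = "/data/wavs"  # hypothetical corpus location
ap = AudioProcessor(sample_rate=config.audio.sample_rate)
# Writes <out_path>/mel/*.npy for every wav, plus <out_path>/quant/*.npy
# when config.mode is an integer bit depth (mu-law encoded if model_params.mulaw).
preprocess_wav_files("/data/wavernn_cache", config, ap)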

TTS/vocoder/datasets/wavegrad_dataset.py

@@ -136,4 +136,4 @@ class WaveGradDataset(Dataset):
             mels[idx, :, : mel.shape[1]] = mel
             audios[idx, : audio.shape[0]] = audio
-        return mels, audios
+        return audios, mels
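
The only change here is the tuple order of the full-clip collate, so consumers must now unpack audio first. A hedged sketch of the consuming side, assuming the padded collate shown above is `WaveGradDataset.collate_full_clips`:

from torch.utils.data import DataLoader

from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset

def make_full_clip_loader(dataset: WaveGradDataset, batch_size: int = 4):
    # After this commit the collate yields (audios, mels), so iterate as:
    #   for audios, mels in loader: ...
    return DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_full_clips)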

TTS/vocoder/datasets/wavernn_dataset.py

@@ -10,16 +10,7 @@ class WaveRNNDataset(Dataset):
     """

     def __init__(
-        self,
-        ap,
-        items,
-        seq_len,
-        hop_len,
-        pad,
-        mode,
-        mulaw,
-        is_training=True,
-        verbose=False,
+        self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, verbose=False, return_segments=True
     ):
         super().__init__()
@@ -34,6 +25,7 @@ class WaveRNNDataset(Dataset):
         self.mulaw = mulaw
         self.is_training = is_training
         self.verbose = verbose
+        self.return_segments = return_segments

         assert self.seq_len % self.hop_len == 0
@@ -44,6 +36,16 @@ class WaveRNNDataset(Dataset):
         item = self.load_item(index)
         return item

+    def load_test_samples(self, num_samples):
+        samples = []
+        return_segments = self.return_segments
+        self.return_segments = False
+        for idx in range(num_samples):
+            mel, audio, _ = self.load_item(idx)
+            samples.append([mel, audio])
+        self.return_segments = return_segments
+        return samples
+
     def load_item(self, index):
         """
         load (audio, feat) couple if feature_path is set
@@ -53,7 +55,10 @@ class WaveRNNDataset(Dataset):
         wavpath = self.item_list[index]
         audio = self.ap.load_wav(wavpath)

-        min_audio_len = 2 * self.seq_len + (2 * self.pad * self.hop_len)
+        if self.return_segments:
+            min_audio_len = 2 * self.seq_len + (2 * self.pad * self.hop_len)
+        else:
+            min_audio_len = audio.shape[0] + (2 * self.pad * self.hop_len)
         if audio.shape[0] < min_audio_len:
             print(" [!] Instance is too short! : {}".format(wavpath))
             audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len])
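
Finally, a hedged sketch of the new `load_test_samples` hook, which temporarily disables segmentation so full clips come back for eval synthesis; the constructor values below are illustrative assumptions, not defaults from this commit:

from TTS.utils.audio import AudioProcessor
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset

ap = AudioProcessor()  # default audio settings; real code fills these from the config
dataset = WaveRNNDataset(
    ap=ap,
    items=["/data/wavs/0001.wav", "/data/wavs/0002.wav"],  # hypothetical wav paths
    seq_len=1280,  # must be a multiple of hop_len (asserted in __init__)
    hop_len=256,
    pad=2,
    mode=10,  # 10-bit quantized output
    mulaw=True,
    is_training=False,
)
# Each entry is [mel, audio] holding the full, unsegmented clip.
test_samples = dataset.load_test_samples(num_samples=2)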