Update `vocoder` datasets and `setup_dataset`

This commit is contained in:
Eren Gölge 2021-06-18 13:42:36 +02:00
parent 59abf490a1
commit aed919cf1c
4 changed files with 89 additions and 14 deletions

View File

@ -0,0 +1,57 @@
from typing import List
from coqpit import Coqpit
from torch.utils.data import Dataset
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool) -> Dataset:
if config.model.lower() in "gan":
dataset = GANDataset(
ap=ap,
items=data_items,
seq_len=config.seq_len,
hop_len=ap.hop_length,
pad_short=config.pad_short,
conv_pad=config.conv_pad,
return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
is_training=not is_eval,
return_segments=not is_eval,
use_noise_augment=config.use_noise_augment,
use_cache=config.use_cache,
verbose=verbose,
)
dataset.shuffle_mapping()
elif config.model.lower() == "wavegrad":
dataset = WaveGradDataset(
ap=ap,
items=data_items,
seq_len=config.seq_len,
hop_len=ap.hop_length,
pad_short=config.pad_short,
conv_pad=config.conv_pad,
is_training=not is_eval,
return_segments=True,
use_noise_augment=False,
use_cache=config.use_cache,
verbose=verbose,
)
elif config.model.lower() == "wavernn":
dataset = WaveRNNDataset(
ap=ap,
items=data_items,
seq_len=config.seq_len,
hop_len=ap.hop_length,
pad=config.model_params.pad,
mode=config.model_params.mode,
mulaw=config.model_params.mulaw,
is_training=not is_eval,
verbose=verbose,
)
else:
raise ValueError(f" [!] Dataset for model {config.model.lower()} cannot be found.")
return dataset

View File

@ -3,10 +3,21 @@ import os
from pathlib import Path
import numpy as np
from coqpit import Coqpit
from tqdm import tqdm
from TTS.utils.audio import AudioProcessor
def preprocess_wav_files(out_path, config, ap):
def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
"""Process wav and compute mel and quantized wave signal.
It is mainly used by WaveRNN dataloader.
Args:
out_path (str): Parent folder path to save the files.
config (Coqpit): Model config.
ap (AudioProcessor): Audio processor.
"""
os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
wav_files = find_wav_files(config.data_path)
@ -18,7 +29,9 @@ def preprocess_wav_files(out_path, config, ap):
mel = ap.melspectrogram(y)
np.save(mel_path, mel)
if isinstance(config.mode, int):
quant = ap.mulaw_encode(y, qc=config.mode) if config.mulaw else ap.quantize(y, bits=config.mode)
quant = (
ap.mulaw_encode(y, qc=config.mode) if config.model_params.mulaw else ap.quantize(y, bits=config.mode)
)
np.save(quant_path, quant)

View File

@ -136,4 +136,4 @@ class WaveGradDataset(Dataset):
mels[idx, :, : mel.shape[1]] = mel
audios[idx, : audio.shape[0]] = audio
return mels, audios
return audios, mels

View File

@ -10,16 +10,7 @@ class WaveRNNDataset(Dataset):
"""
def __init__(
self,
ap,
items,
seq_len,
hop_len,
pad,
mode,
mulaw,
is_training=True,
verbose=False,
self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, verbose=False, return_segments=True
):
super().__init__()
@ -34,6 +25,7 @@ class WaveRNNDataset(Dataset):
self.mulaw = mulaw
self.is_training = is_training
self.verbose = verbose
self.return_segments = return_segments
assert self.seq_len % self.hop_len == 0
@ -44,6 +36,16 @@ class WaveRNNDataset(Dataset):
item = self.load_item(index)
return item
def load_test_samples(self, num_samples):
samples = []
return_segments = self.return_segments
self.return_segments = False
for idx in range(num_samples):
mel, audio, _ = self.load_item(idx)
samples.append([mel, audio])
self.return_segments = return_segments
return samples
def load_item(self, index):
"""
load (audio, feat) couple if feature_path is set
@ -53,7 +55,10 @@ class WaveRNNDataset(Dataset):
wavpath = self.item_list[index]
audio = self.ap.load_wav(wavpath)
min_audio_len = 2 * self.seq_len + (2 * self.pad * self.hop_len)
if self.return_segments:
min_audio_len = 2 * self.seq_len + (2 * self.pad * self.hop_len)
else:
min_audio_len = audio.shape[0] + (2 * self.pad * self.hop_len)
if audio.shape[0] < min_audio_len:
print(" [!] Instance is too short! : {}".format(wavpath))
audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len])