mirror of https://github.com/coqui-ai/TTS.git
Update `vocoder` datasets and `setup_dataset`
parent d18198dff8
commit d7225eedb0
@@ -0,0 +1,57 @@
+from typing import List
+
+from coqpit import Coqpit
+from torch.utils.data import Dataset
+
+from TTS.utils.audio import AudioProcessor
+from TTS.vocoder.datasets.gan_dataset import GANDataset
+from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
+from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
+
+
+def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool) -> Dataset:
+    if "gan" in config.model.lower():
+        dataset = GANDataset(
+            ap=ap,
+            items=data_items,
+            seq_len=config.seq_len,
+            hop_len=ap.hop_length,
+            pad_short=config.pad_short,
+            conv_pad=config.conv_pad,
+            return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
+            is_training=not is_eval,
+            return_segments=not is_eval,
+            use_noise_augment=config.use_noise_augment,
+            use_cache=config.use_cache,
+            verbose=verbose,
+        )
+        dataset.shuffle_mapping()
+    elif config.model.lower() == "wavegrad":
+        dataset = WaveGradDataset(
+            ap=ap,
+            items=data_items,
+            seq_len=config.seq_len,
+            hop_len=ap.hop_length,
+            pad_short=config.pad_short,
+            conv_pad=config.conv_pad,
+            is_training=not is_eval,
+            return_segments=True,
+            use_noise_augment=False,
+            use_cache=config.use_cache,
+            verbose=verbose,
+        )
+    elif config.model.lower() == "wavernn":
+        dataset = WaveRNNDataset(
+            ap=ap,
+            items=data_items,
+            seq_len=config.seq_len,
+            hop_len=ap.hop_length,
+            pad=config.model_params.pad,
+            mode=config.model_params.mode,
+            mulaw=config.model_params.mulaw,
+            is_training=not is_eval,
+            verbose=verbose,
+        )
+    else:
+        raise ValueError(f" [!] Dataset for model {config.model.lower()} cannot be found.")
+    return dataset
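For orientation, a minimal sketch of how the new `setup_dataset` factory might be driven. This is not part of the commit: the config class, the `audio` section, the wav paths, and the batch size are illustrative assumptions, and the import assumes the new file above is `TTS/vocoder/datasets/__init__.py`.

from torch.utils.data import DataLoader

from TTS.utils.audio import AudioProcessor
from TTS.vocoder.configs import WavegradConfig  # assumed import path
from TTS.vocoder.datasets import setup_dataset  # assumes the new file is the package __init__

config = WavegradConfig()                   # any vocoder config exposing `model`, `seq_len`, etc.
ap = AudioProcessor(**config.audio)         # assumes the config carries an `audio` section
data_items = ["/data/wavs/utt_0001.wav"]    # hypothetical list of wav paths

# is_eval=False -> segment-based training mode for the GAN/WaveGrad datasets
dataset = setup_dataset(config, ap, is_eval=False, data_items=data_items, verbose=True)
loader = DataLoader(dataset, batch_size=8, shuffle=True)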
@@ -3,10 +3,21 @@ import os
 from pathlib import Path
 
 import numpy as np
+from coqpit import Coqpit
 from tqdm import tqdm
 
+from TTS.utils.audio import AudioProcessor
+
 
-def preprocess_wav_files(out_path, config, ap):
+def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
+    """Process wav and compute mel and quantized wave signal.
+    It is mainly used by WaveRNN dataloader.
+
+    Args:
+        out_path (str): Parent folder path to save the files.
+        config (Coqpit): Model config.
+        ap (AudioProcessor): Audio processor.
+    """
     os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
     os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
     wav_files = find_wav_files(config.data_path)
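A short sketch of the annotated helper in use and the cache layout it writes, grounded in the `os.makedirs`/`np.save` calls in this file; the concrete paths and the config class are hypothetical.

from TTS.utils.audio import AudioProcessor
from TTS.vocoder.configs import WavernnConfig  # assumed import path
from TTS.vocoder.datasets.preprocess import preprocess_wav_files

config = WavernnConfig()
config.data_path = "/data/LJSpeech/wavs"   # hypothetical corpus location
ap = AudioProcessor(**config.audio)
preprocess_wav_files("/data/LJSpeech/cache", config, ap)
# Writes one .npy per input wav, keyed by the file stem:
#   /data/LJSpeech/cache/mel/<stem>.npy    mel spectrogram
#   /data/LJSpeech/cache/quant/<stem>.npy  quantized signal (only when config.mode is an int)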
@@ -18,7 +29,9 @@ def preprocess_wav_files(out_path, config, ap):
         mel = ap.melspectrogram(y)
         np.save(mel_path, mel)
         if isinstance(config.mode, int):
-            quant = ap.mulaw_encode(y, qc=config.mode) if config.mulaw else ap.quantize(y, bits=config.mode)
+            quant = (
+                ap.mulaw_encode(y, qc=config.mode) if config.model_params.mulaw else ap.quantize(y, bits=config.mode)
+            )
             np.save(quant_path, quant)
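The rewritten branch chooses mu-law companding when `config.model_params.mulaw` is set and uniform quantization otherwise. Below is a self-contained NumPy sketch of the two transforms, mirroring (not calling) the `AudioProcessor` methods; treating `config.mode` as a bit depth is an assumption here.

import numpy as np

def mulaw_encode_sketch(wav, qc):
    # compand a [-1, 1] signal, then map it onto 2**qc discrete levels
    mu = 2 ** qc - 1
    signal = np.sign(wav) * np.log1p(mu * np.abs(wav)) / np.log1p(mu)
    return np.floor((signal + 1) / 2 * mu + 0.5)

def quantize_sketch(wav, bits):
    # uniform quantization of a [-1, 1] signal onto 2**bits levels
    return (wav + 1.0) * (2 ** bits - 1) / 2

y = np.sin(np.linspace(0, 2 * np.pi, 8))
print(mulaw_encode_sketch(y, qc=10))  # path taken when mulaw is truthy
print(quantize_sketch(y, bits=10))    # path taken otherwise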
@@ -136,4 +136,4 @@ class WaveGradDataset(Dataset):
             mels[idx, :, : mel.shape[1]] = mel
             audios[idx, : audio.shape[0]] = audio
 
-        return mels, audios
+        return audios, mels
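The only change here is the return order, so callers now unpack full clips as (audio, mel). Assuming this return sits in `WaveGradDataset.collate_full_clips`, call sites change accordingly:

# before this commit:  mels, audios = dataset.collate_full_clips(batch)
audios, mels = dataset.collate_full_clips(batch)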
@@ -10,16 +10,7 @@ class WaveRNNDataset(Dataset):
     """
 
     def __init__(
-        self,
-        ap,
-        items,
-        seq_len,
-        hop_len,
-        pad,
-        mode,
-        mulaw,
-        is_training=True,
-        verbose=False,
+        self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, verbose=False, return_segments=True
     ):
 
         super().__init__()
@@ -34,6 +25,7 @@ class WaveRNNDataset(Dataset):
         self.mulaw = mulaw
         self.is_training = is_training
         self.verbose = verbose
+        self.return_segments = return_segments
 
         assert self.seq_len % self.hop_len == 0
 
@@ -44,6 +36,16 @@ class WaveRNNDataset(Dataset):
         item = self.load_item(index)
         return item
 
+    def load_test_samples(self, num_samples):
+        samples = []
+        return_segments = self.return_segments
+        self.return_segments = False
+        for idx in range(num_samples):
+            mel, audio, _ = self.load_item(idx)
+            samples.append([mel, audio])
+        self.return_segments = return_segments
+        return samples
+
     def load_item(self, index):
         """
         load (audio, feat) couple if feature_path is set
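A sketch of the new helper in use: it temporarily clears `return_segments` so `load_item` yields whole clips for evaluation, then restores the flag. The constructor values and the `ap`/`wav_paths` objects are illustrative, as in the earlier sketches.

from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset

dataset = WaveRNNDataset(
    ap=ap, items=wav_paths, seq_len=1280, hop_len=256,
    pad=2, mode=10, mulaw=True, is_training=False,
)
samples = dataset.load_test_samples(num_samples=4)  # list of [mel, audio] full clips
for mel, audio in samples:
    print(mel.shape, audio.shape)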
@@ -53,7 +55,10 @@ class WaveRNNDataset(Dataset):
 
         wavpath = self.item_list[index]
         audio = self.ap.load_wav(wavpath)
-        min_audio_len = 2 * self.seq_len + (2 * self.pad * self.hop_len)
+        if self.return_segments:
+            min_audio_len = 2 * self.seq_len + (2 * self.pad * self.hop_len)
+        else:
+            min_audio_len = audio.shape[0] + (2 * self.pad * self.hop_len)
         if audio.shape[0] < min_audio_len:
             print(" [!] Instance is too short! : {}".format(wavpath))
             audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len])
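A worked example of the two thresholds above, with illustrative hyperparameters. Note that in the full-clip branch the threshold always exceeds the clip length, so every eval clip gets the convolution-context padding.

seq_len, hop_len, pad = 1280, 256, 2

# return_segments=True (training): room for two segments plus context padding
min_train = 2 * seq_len + (2 * pad * hop_len)  # 2560 + 1024 = 3584 samples

# return_segments=False (eval): the whole clip plus context padding
audio_len = 3000
min_eval = audio_len + (2 * pad * hop_len)     # 3000 + 1024 = 4024 samples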