From 1c2dc9f73909590aed91734cc1ea0a83ee1fb97c Mon Sep 17 00:00:00 2001
From: erogol
Date: Fri, 12 Jun 2020 11:12:57 +0200
Subject: [PATCH] enable loading precomputed vocoder dataset

---
Reviewer notes (commentary only; this section between the three-dash line
and the first file diff is the conventional place for patch notes and is not
part of the commit):

 * gan_dataset.py, __init__: compute_feat is derived from items[0] alone, so
   the dataset mode is decided by the first entry only, and an empty item
   list raises IndexError at this line.
 * gan_dataset.py, load_item: both branches repeat the same cache check.
   They differ only in how the item is unpacked (a wav path vs. a
   (wav path, feature path) pair) and in where mel comes from
   (self.ap.melspectrogram(audio) vs. np.load(feat_path)).
 * preprocess.py, load_wav_feat_data: wav and feature files are paired by
   sorted order; equal counts and equal file stems are checked with assert
   statements, which Python removes when run with -O, so in that mode no
   validation happens.
 * preprocess.py: load_wav_data and load_wav_feat_data both call
   np.random.seed(0), which reseeds NumPy's global random state before the
   shuffle.
 * train.py: load_wav_feat_data is imported, but no hunk in this patch calls
   it from train.py (the diffstat shows only the import line and the
   reordered argument changed there).
 * train.py, main: the second hunk swaps scheduler_gen and optimizer_gen in
   the save_best_model(...) call. The signature of save_best_model is not
   visible in this patch -- TODO confirm the new positional order matches it.
 * NOTE(review): this patch text was recovered from a copy whose line breaks
   had been collapsed; line structure follows the hunk counts, leading
   indentation is inferred from the Python structure -- verify against the
   commit before applying.

 vocoder/datasets/gan_dataset.py | 25 +++++++++++++++++++------
 vocoder/datasets/preprocess.py  | 21 +++++++++++++++++++++
 vocoder/train.py                |  4 ++--
 3 files changed, 42 insertions(+), 8 deletions(-)

diff --git a/vocoder/datasets/gan_dataset.py b/vocoder/datasets/gan_dataset.py
index 5e410151..55513e7d 100644
--- a/vocoder/datasets/gan_dataset.py
+++ b/vocoder/datasets/gan_dataset.py
@@ -28,6 +28,7 @@ class GANDataset(Dataset):
 
         self.ap = ap
         self.item_list = items
+        self.compute_feat = not isinstance(items[0], (tuple, list))
         self.seq_len = seq_len
         self.hop_len = hop_len
         self.pad_short = pad_short
@@ -77,14 +78,26 @@ class GANDataset(Dataset):
 
     def load_item(self, idx):
         """ load (audio, feat) couple """
-        wavpath = self.item_list[idx]
-        # print(wavpath)
+        if self.compute_feat:
+            # compute features from wav
+            wavpath = self.item_list[idx]
+            # print(wavpath)
 
-        if self.use_cache and self.cache[idx] is not None:
-            audio, mel = self.cache[idx]
+            if self.use_cache and self.cache[idx] is not None:
+                audio, mel = self.cache[idx]
+            else:
+                audio = self.ap.load_wav(wavpath)
+                mel = self.ap.melspectrogram(audio)
         else:
-            audio = self.ap.load_wav(wavpath)
-            mel = self.ap.melspectrogram(audio)
+
+            # load precomputed features
+            wavpath, feat_path = self.item_list[idx]
+
+            if self.use_cache and self.cache[idx] is not None:
+                audio, mel = self.cache[idx]
+            else:
+                audio = self.ap.load_wav(wavpath)
+                mel = np.load(feat_path)
 
         if len(audio) < self.seq_len + self.pad_short:
             audio = np.pad(audio, (0, self.seq_len + self.pad_short - len(audio)), \
diff --git a/vocoder/datasets/preprocess.py b/vocoder/datasets/preprocess.py
index 01e01e3e..be60c13a 100644
--- a/vocoder/datasets/preprocess.py
+++ b/vocoder/datasets/preprocess.py
@@ -1,5 +1,6 @@
 import glob
 import os
+from pathlib import Path
 
 import numpy as np
 
@@ -9,8 +10,28 @@ def find_wav_files(data_path):
     return wav_paths
 
 
+def find_feat_files(data_path):
+    feat_paths = glob.glob(os.path.join(data_path, '**', '*.npy'), recursive=True)
+    return feat_paths
+
+
 def load_wav_data(data_path, eval_split_size):
     wav_paths = find_wav_files(data_path)
     np.random.seed(0)
     np.random.shuffle(wav_paths)
     return wav_paths[:eval_split_size], wav_paths[eval_split_size:]
+
+
+def load_wav_feat_data(data_path, feat_path, eval_split_size):
+    wav_paths = sorted(find_wav_files(data_path))
+    feat_paths = sorted(find_feat_files(feat_path))
+    assert len(wav_paths) == len(feat_paths)
+    for wav, feat in zip(wav_paths, feat_paths):
+        wav_name = Path(wav).stem
+        feat_name = Path(feat).stem
+        assert wav_name == feat_name
+
+    items = list(zip(wav_paths, feat_paths))
+    np.random.seed(0)
+    np.random.shuffle(items)
+    return items[:eval_split_size], items[eval_split_size:]
diff --git a/vocoder/train.py b/vocoder/train.py
index 74b3e759..4d1b1029 100644
--- a/vocoder/train.py
+++ b/vocoder/train.py
@@ -19,7 +19,7 @@ from TTS.utils.radam import RAdam
 from TTS.utils.tensorboard_logger import TensorboardLogger
 from TTS.utils.training import setup_torch_training_env
 from TTS.vocoder.datasets.gan_dataset import GANDataset
-from TTS.vocoder.datasets.preprocess import load_wav_data
+from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
 # from distribute import (DistributedSampler, apply_gradient_allreduce,
 #                         init_distributed, reduce_tensor)
 from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
@@ -543,8 +543,8 @@ def main(args):  # pylint: disable=redefined-outer-name
         best_loss = save_best_model(target_loss,
                                     best_loss,
                                     model_gen,
-                                    scheduler_gen,
                                     optimizer_gen,
+                                    scheduler_gen,
                                     model_disc,
                                     optimizer_disc,
                                     scheduler_disc,