From 7b2804cc0d02f2ffa9012ae6271f81205fc2d55c Mon Sep 17 00:00:00 2001 From: Thomas Werkmeister Date: Mon, 29 Apr 2019 11:07:04 +0200 Subject: [PATCH] dropped dataset caching --- README.md | 4 +- config.json | 2 +- config_cluster.json | 2 +- datasets/TTSDataset.py | 23 ++------ datasets/preprocess.py | 12 ---- extract_features.py | 126 ----------------------------------------- tests/loader_tests.py | 18 +++--- tests/test_config.json | 1 - train.py | 1 - 9 files changed, 15 insertions(+), 174 deletions(-) delete mode 100644 extract_features.py diff --git a/README.md b/README.md index 24d0e0f8..c8953c18 100644 --- a/README.md +++ b/README.md @@ -69,9 +69,7 @@ Audio length is approximately 6 secs. ## Datasets and Data-Loading -TTS provides a generic dataloder easy to use for new datasets. You need to write an adaptor to format and that's all you need.Check ```datasets/preprocess.py``` to see example adaptors. After you wrote an adaptor, you need to set ```dataset``` field in ```config.json```. Do not forget other data related fields. - -You can also use pre-computed features. In this case, compute features with ```extract_features.py``` and set ```dataset``` field as ```tts_cache```. +TTS provides a generic dataloader easy to use for new datasets. You need to write an adaptor to format and that's all you need. Check ```datasets/preprocess.py``` to see example adaptors. After you wrote an adaptor, you need to set ```dataset``` field in ```config.json```. Do not forget other data related fields. Example datasets, we successfully applied TTS, are linked below. diff --git a/config.json b/config.json index 38ca74e1..c0222bb1 100644 --- a/config.json +++ b/config.json @@ -62,7 +62,7 @@ "data_path": "/home/erogol/Data/LJSpeech-1.1", // DATASET-RELATED: can overwritten from command argument "meta_file_train": "metadata_train.csv", // DATASET-RELATED: metafile for training dataloader. "meta_file_val": "metadata_val.csv", // DATASET-RELATED: metafile for evaluation dataloader. 
- "dataset": "ljspeech", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. Use "tts_cache" for pre-computed dataset by extract_features.py + "dataset": "ljspeech", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. "min_seq_len": 0, // DATASET-RELATED: minimum text length to use in training "max_seq_len": 150, // DATASET-RELATED: maximum text length "output_path": "/media/erogol/data_ssd/Data/models/ljspeech_models/", // DATASET-RELATED: output path for all training outputs. diff --git a/config_cluster.json b/config_cluster.json index fe227a01..5a6dec30 100644 --- a/config_cluster.json +++ b/config_cluster.json @@ -61,7 +61,7 @@ "data_path": "/media/erogol/data_ssd/Data/LJSpeech-1.1", // DATASET-RELATED: can overwritten from command argument "meta_file_train": "prompts_train.data", // DATASET-RELATED: metafile for training dataloader. "meta_file_val": "prompts_val.data", // DATASET-RELATED: metafile for evaluation dataloader. - "dataset": "mozilla", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. Use "tts_cache" for pre-computed dataset by extract_features.py + "dataset": "mozilla", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. "min_seq_len": 0, // DATASET-RELATED: minimum text length to use in training "max_seq_len": 150, // DATASET-RELATED: maximum text length "output_path": "../keep/", // DATASET-RELATED: output path for all training outputs. 
diff --git a/datasets/TTSDataset.py b/datasets/TTSDataset.py index 16b016f4..7d5c902f 100644 --- a/datasets/TTSDataset.py +++ b/datasets/TTSDataset.py @@ -22,7 +22,6 @@ class MyDataset(Dataset): batch_group_size=0, min_seq_len=0, max_seq_len=float("inf"), - cached=False, use_phonemes=True, phoneme_cache_path=None, phoneme_language="en-us", @@ -61,7 +60,6 @@ class MyDataset(Dataset): self.min_seq_len = min_seq_len self.max_seq_len = max_seq_len self.ap = ap - self.cached = cached self.use_phonemes = use_phonemes self.phoneme_cache_path = phoneme_cache_path self.phoneme_language = phoneme_language @@ -113,23 +111,10 @@ class MyDataset(Dataset): return text def load_data(self, idx): - if self.cached: - wav_name = self.items[idx][1] - mel_name = self.items[idx][2] - linear_name = self.items[idx][3] - text = self.items[idx][0] - - if wav_name.split('.')[-1] == 'npy': - wav = self.load_np(wav_name) - else: - wav = np.asarray(self.load_wav(wav_name), dtype=np.float32) - mel = self.load_np(mel_name) - linear = self.load_np(linear_name) - else: - text, wav_file = self.items[idx] - wav = np.asarray(self.load_wav(wav_file), dtype=np.float32) - mel = None - linear = None + text, wav_file = self.items[idx] + wav = np.asarray(self.load_wav(wav_file), dtype=np.float32) + mel = None + linear = None if self.use_phonemes: text = self.load_phoneme_sequence(wav_file, text) diff --git a/datasets/preprocess.py b/datasets/preprocess.py index c498577e..a2cac332 100644 --- a/datasets/preprocess.py +++ b/datasets/preprocess.py @@ -1,18 +1,6 @@ import os -def tts_cache(root_path, meta_file): - """This format is set for the meta-file generated by extract_features.py""" - txt_file = os.path.join(root_path, meta_file) - items = [] - with open(txt_file, 'r', encoding='utf8') as f: - for line in f: - cols = line.split('| ') - # text, wav_full_path, mel_name, linear_name, wav_len, mel_len - items.append(cols) - return items - - def tweb(root_path, meta_file): """Normalize TWEB dataset. 
https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset diff --git a/extract_features.py b/extract_features.py deleted file mode 100644 index 3e617b3f..00000000 --- a/extract_features.py +++ /dev/null @@ -1,126 +0,0 @@ -''' -Extract spectrograms and save them to file for training -''' -import os -import sys -import time -import glob -import argparse -import librosa -import importlib -import numpy as np -import tqdm -from utils.generic_utils import load_config, copy_config_file -from utils.audio import AudioProcessor - -from multiprocessing import Pool - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--data_path', type=str, help='Data folder.') - parser.add_argument('--cache_path', type=str, help='Cache folder, place to output all the spectrogram files.') - parser.add_argument( - '--config', type=str, help='conf.json file for run settings.') - parser.add_argument( - "--num_proc", type=int, default=8, help="number of processes.") - parser.add_argument( - "--trim_silence", - type=bool, - default=False, - help="trim silence in the voice clip.") - parser.add_argument("--only_mel", type=bool, default=False, help="If True, only melsceptrogram is extracted.") - parser.add_argument("--dataset", type=str, help="Target dataset to be processed.") - parser.add_argument("--val_split", type=int, default=0, help="Number of instances for validation.") - parser.add_argument("--meta_file", type=str, help="Meta data file to be used for the dataset.") - parser.add_argument("--process_audio", type=bool, default=False, help="Preprocess audio files.") - args = parser.parse_args() - - DATA_PATH = args.data_path - CACHE_PATH = args.cache_path - CONFIG = load_config(args.config) - - # load the right preprocessor - preprocessor = importlib.import_module('datasets.preprocess') - preprocessor = getattr(preprocessor, args.dataset.lower()) - items = preprocessor(args.data_path, args.meta_file) - - print(" > Input path: ", DATA_PATH) - 
print(" > Cache path: ", CACHE_PATH) - - ap = AudioProcessor(**CONFIG.audio) - - - def extract_mel(item): - """ Compute spectrograms, length information """ - text = item[0] - file_path = item[1] - x = ap.load_wav(file_path, ap.sample_rate) - file_name = os.path.basename(file_path).replace(".wav", "") - mel_file = file_name + "_mel" - mel_path = os.path.join(CACHE_PATH, 'mel', mel_file) - mel = ap.melspectrogram(x.astype('float32')).astype('float32') - np.save(mel_path, mel, allow_pickle=False) - mel_len = mel.shape[1] - wav_len = x.shape[0] - output = [text, file_path, mel_path+".npy", str(wav_len), str(mel_len)] - if not args.only_mel: - linear_file = file_name + "_linear" - linear_path = os.path.join(CACHE_PATH, 'linear', linear_file) - linear = ap.spectrogram(x.astype('float32')).astype('float32') - linear_len = linear.shape[1] - np.save(linear_path, linear, allow_pickle=False) - output.insert(3, linear_path+".npy") - assert mel_len == linear_len - if args.process_audio: - audio_file = file_name + "_audio" - audio_path = os.path.join(CACHE_PATH, 'audio', audio_file) - np.save(audio_path, x, allow_pickle=False) - del output[0] - output.insert(1, audio_path+".npy") - return output - - - if __name__ == "__main__": - print(" > Number of files: %i" % (len(items))) - if not os.path.exists(CACHE_PATH): - os.makedirs(os.path.join(CACHE_PATH, 'mel')) - if not args.only_mel: - os.makedirs(os.path.join(CACHE_PATH, 'linear')) - if args.process_audio: - os.makedirs(os.path.join(CACHE_PATH, 'audio')) - print(" > A new folder created at {}".format(CACHE_PATH)) - - # Extract features - r = [] - if args.num_proc > 1: - print(" > Using {} processes.".format(args.num_proc)) - with Pool(args.num_proc) as p: - r = list( - tqdm.tqdm( - p.imap(extract_mel, items), - total=len(items))) - # r = list(p.imap(extract_mel, file_names)) - else: - print(" > Using single process run.") - for item in items: - print(" > ", item[1]) - r.append(extract_mel(item)) - - # Save meta data - if 
args.cache_path is not None: - file_path = os.path.join(CACHE_PATH, "tts_metadata_val.csv") - file = open(file_path, "w") - for line in r[:args.val_split]: - line = "| ".join(line) - file.write(line + '\n') - file.close() - - file_path = os.path.join(CACHE_PATH, "tts_metadata.csv") - file = open(file_path, "w") - for line in r[args.val_split:]: - line = "| ".join(line) - file.write(line + '\n') - file.close() - - # copy the used config file to output path for sanity - copy_config_file(args.config, CACHE_PATH) diff --git a/tests/loader_tests.py b/tests/loader_tests.py index a70cdfc3..0830cdc9 100644 --- a/tests/loader_tests.py +++ b/tests/loader_tests.py @@ -7,7 +7,7 @@ from torch.utils.data import DataLoader from utils.generic_utils import load_config from utils.audio import AudioProcessor from datasets import TTSDataset -from datasets.preprocess import ljspeech, tts_cache +from datasets.preprocess import ljspeech file_path = os.path.dirname(os.path.realpath(__file__)) OUTPATH = os.path.join(file_path, "outputs/loader_tests/") @@ -16,15 +16,11 @@ c = load_config(os.path.join(file_path, 'test_config.json')) ok_ljspeech = os.path.exists(c.data_path) DATA_EXIST = True -CACHE_EXIST = True -if not os.path.exists(c.data_path_cache): - CACHE_EXIST = False - if not os.path.exists(c.data_path): DATA_EXIST = False print(" > Dynamic data loader test: {}".format(DATA_EXIST)) -print(" > Cache data loader test: {}".format(CACHE_EXIST)) + class TestTTSDataset(unittest.TestCase): def __init__(self, *args, **kwargs): @@ -126,8 +122,9 @@ class TestTTSDataset(unittest.TestCase): wav = self.ap.load_wav(item_idx[0]) mel = self.ap.melspectrogram(wav) mel_dl = mel_input[0].cpu().numpy() - assert ( - abs(mel.T).astype("float32") - abs(mel_dl[:-1])).sum() == 0 + assert (abs(mel.T).astype("float32") + - abs(mel_dl[:-1]) + ).sum() == 0 # check mel-spec correctness mel_spec = mel_input[0].cpu().numpy() @@ -139,7 +136,8 @@ class TestTTSDataset(unittest.TestCase): linear_spec = 
linear_input[0].cpu().numpy() wav = self.ap.inv_spectrogram(linear_spec.T) self.ap.save_wav(wav, OUTPATH + '/linear_inv_dataloader.wav') - shutil.copy(item_idx[0], OUTPATH + '/linear_target_dataloader.wav') + shutil.copy(item_idx[0], + OUTPATH + '/linear_target_dataloader.wav') # check the last time step to be zero padded assert linear_input[0, -1].sum() == 0 @@ -192,4 +190,4 @@ class TestTTSDataset(unittest.TestCase): # check batch conditions assert (linear_input * stop_target.unsqueeze(2)).sum() == 0 - assert (mel_input * stop_target.unsqueeze(2)).sum() == 0 \ No newline at end of file + assert (mel_input * stop_target.unsqueeze(2)).sum() == 0 diff --git a/tests/test_config.json b/tests/test_config.json index b4436572..82d9bdd3 100644 --- a/tests/test_config.json +++ b/tests/test_config.json @@ -36,7 +36,6 @@ "save_step": 200, "data_path": "/home/erogol/Data/LJSpeech-1.1/", - "data_path_cache": "/media/erogol/data_ssd/Data/Nancy/tts_cache/", "output_path": "result", "min_seq_len": 0, "max_seq_len": 300, diff --git a/train.py b/train.py index f9c80ebe..58b6ebd0 100644 --- a/train.py +++ b/train.py @@ -53,7 +53,6 @@ def setup_loader(is_val=False, verbose=False): batch_group_size=0 if is_val else c.batch_group_size * c.batch_size, min_seq_len=0 if is_val else c.min_seq_len, max_seq_len=float("inf") if is_val else c.max_seq_len, - cached=False if c.dataset != "tts_cache" else True, phoneme_cache_path=c.phoneme_cache_path, use_phonemes=c.use_phonemes, phoneme_language=c.phoneme_language,