Merge branch 'drop_cached_dataset' of https://github.com/twerkmeister/TTS into twerkmeister-drop_cached_dataset

This commit is contained in:
Eren Golge 2019-05-14 22:37:39 +02:00
commit 51f71c068a
8 changed files with 16 additions and 187 deletions

View File

@ -76,8 +76,6 @@ Audio length is approximately 6 secs.
## Datasets and Data-Loading ## Datasets and Data-Loading
TTS provides a generic dataloader that is easy to use for new datasets. You need to write an adaptor to format the dataset, and that's all you need. Check ```datasets/preprocess.py``` to see example adaptors. After you have written an adaptor, you need to set the ```dataset``` field in ```config.json```. Do not forget the other data-related fields. TTS provides a generic dataloader that is easy to use for new datasets. You need to write an adaptor to format the dataset, and that's all you need. Check ```datasets/preprocess.py``` to see example adaptors. After you have written an adaptor, you need to set the ```dataset``` field in ```config.json```. Do not forget the other data-related fields.
You can also use pre-computed features. In this case, compute features with ```extract_features.py``` and set ```dataset``` field as ```tts_cache```.
Example datasets to which we have successfully applied TTS are linked below. Example datasets to which we have successfully applied TTS are linked below.
- [LJ Speech](https://keithito.com/LJ-Speech-Dataset/) - [LJ Speech](https://keithito.com/LJ-Speech-Dataset/)

View File

@ -62,7 +62,7 @@
"data_path": "/home/erogol/Data/LJSpeech-1.1", // DATASET-RELATED: can overwritten from command argument "data_path": "/home/erogol/Data/LJSpeech-1.1", // DATASET-RELATED: can overwritten from command argument
"meta_file_train": "metadata_train.csv", // DATASET-RELATED: metafile for training dataloader. "meta_file_train": "metadata_train.csv", // DATASET-RELATED: metafile for training dataloader.
"meta_file_val": "metadata_val.csv", // DATASET-RELATED: metafile for evaluation dataloader. "meta_file_val": "metadata_val.csv", // DATASET-RELATED: metafile for evaluation dataloader.
"dataset": "ljspeech", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. Use "tts_cache" for pre-computed dataset by extract_features.py "dataset": "ljspeech", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset.
"min_seq_len": 0, // DATASET-RELATED: minimum text length to use in training "min_seq_len": 0, // DATASET-RELATED: minimum text length to use in training
"max_seq_len": 150, // DATASET-RELATED: maximum text length "max_seq_len": 150, // DATASET-RELATED: maximum text length
"output_path": "/media/erogol/data_ssd/Data/models/ljspeech_models/", // DATASET-RELATED: output path for all training outputs. "output_path": "/media/erogol/data_ssd/Data/models/ljspeech_models/", // DATASET-RELATED: output path for all training outputs.

View File

@ -22,7 +22,6 @@ class MyDataset(Dataset):
batch_group_size=0, batch_group_size=0,
min_seq_len=0, min_seq_len=0,
max_seq_len=float("inf"), max_seq_len=float("inf"),
cached=False,
use_phonemes=True, use_phonemes=True,
phoneme_cache_path=None, phoneme_cache_path=None,
phoneme_language="en-us", phoneme_language="en-us",
@ -61,7 +60,6 @@ class MyDataset(Dataset):
self.min_seq_len = min_seq_len self.min_seq_len = min_seq_len
self.max_seq_len = max_seq_len self.max_seq_len = max_seq_len
self.ap = ap self.ap = ap
self.cached = cached
self.use_phonemes = use_phonemes self.use_phonemes = use_phonemes
self.phoneme_cache_path = phoneme_cache_path self.phoneme_cache_path = phoneme_cache_path
self.phoneme_language = phoneme_language self.phoneme_language = phoneme_language
@ -110,23 +108,8 @@ class MyDataset(Dataset):
return text return text
def load_data(self, idx): def load_data(self, idx):
if self.cached: text, wav_file = self.items[idx]
wav_name = self.items[idx][1] wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
mel_name = self.items[idx][2]
linear_name = self.items[idx][3]
text = self.items[idx][0]
if wav_name.split('.')[-1] == 'npy':
wav = self.load_np(wav_name)
else:
wav = np.asarray(self.load_wav(wav_name), dtype=np.float32)
mel = self.load_np(mel_name)
linear = self.load_np(linear_name)
else:
text, wav_file = self.items[idx]
wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
mel = None
linear = None
if self.use_phonemes: if self.use_phonemes:
text = self.load_phoneme_sequence(wav_file, text) text = self.load_phoneme_sequence(wav_file, text)
@ -140,9 +123,7 @@ class MyDataset(Dataset):
sample = { sample = {
'text': text, 'text': text,
'wav': wav, 'wav': wav,
'item_idx': self.items[idx][1], 'item_idx': self.items[idx][1]
'mel': mel,
'linear': linear
} }
return sample return sample
@ -205,17 +186,9 @@ class MyDataset(Dataset):
] ]
text = [batch[idx]['text'] for idx in ids_sorted_decreasing] text = [batch[idx]['text'] for idx in ids_sorted_decreasing]
# if specs are not computed, compute them. mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
if batch[0]['mel'] is None and batch[0]['linear'] is None: linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
mel = [
self.ap.melspectrogram(w).astype('float32') for w in wav
]
linear = [
self.ap.spectrogram(w).astype('float32') for w in wav
]
else:
mel = [d['mel'] for d in batch]
linear = [d['linear'] for d in batch]
mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame
# compute 'stop token' targets # compute 'stop token' targets

View File

@ -1,18 +1,6 @@
import os import os
def tts_cache(root_path, meta_file):
    """Read the meta-file generated by extract_features.py.

    Every non-blank line of the file is split on the ``'| '`` separator and
    returned as one list of column strings. The expected column order is
    text, wav_full_path, mel_name, linear_name, wav_len, mel_len.

    Args:
        root_path: folder that contains the meta-file.
        meta_file: name of the meta-file, relative to ``root_path``.

    Returns:
        A list with one list of column strings per meta-file line.
    """
    txt_file = os.path.join(root_path, meta_file)
    items = []
    with open(txt_file, 'r', encoding='utf8') as f:
        for line in f:
            # Drop the line terminator so the last column does not end with
            # '\n'.
            line = line.rstrip('\n')
            # A blank line carries no columns; keeping it would produce an
            # item that cannot be indexed like a real entry.
            if not line.strip():
                continue
            # text, wav_full_path, mel_name, linear_name, wav_len, mel_len
            items.append(line.split('| '))
    return items
def tweb(root_path, meta_file): def tweb(root_path, meta_file):
"""Normalize TWEB dataset. """Normalize TWEB dataset.
https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset

View File

@ -1,126 +0,0 @@
'''
Extract spectrograms and save them to file for training
'''
import os
import sys
import time
import glob
import argparse
import librosa
import importlib
import numpy as np
import tqdm
from utils.generic_utils import load_config, copy_config_file
from utils.audio import AudioProcessor
from multiprocessing import Pool
def str2bool(value):
    """Convert a command-line string into a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value -- including "False" -- is parsed as ``True``.
    This converter interprets the text instead.

    Args:
        value: raw option value; an actual bool is returned unchanged.

    Returns:
        ``True`` for "true", "t", "yes", "y" or "1" and ``False`` for
        "false", "f", "no", "n", "0" or the empty string (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: if the value is none of the above.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays False, as it was with ``type=bool``.
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "Boolean value expected, got {!r}.".format(value))


if __name__ == "__main__":
    # Command-line interface of the feature extraction script.
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, help='Data folder.')
    parser.add_argument(
        '--cache_path',
        type=str,
        help='Cache folder, place to output all the spectrogram files.')
    parser.add_argument(
        '--config', type=str, help='conf.json file for run settings.')
    parser.add_argument(
        "--num_proc", type=int, default=8, help="number of processes.")
    # NOTE(review): no use of args.trim_silence is visible in this script.
    parser.add_argument(
        "--trim_silence",
        type=str2bool,
        default=False,
        help="trim silence in the voice clip.")
    parser.add_argument(
        "--only_mel",
        type=str2bool,
        default=False,
        help="If True, only melsceptrogram is extracted.")
    parser.add_argument(
        "--dataset", type=str, help="Target dataset to be processed.")
    parser.add_argument(
        "--val_split",
        type=int,
        default=0,
        help="Number of instances for validation.")
    parser.add_argument(
        "--meta_file",
        type=str,
        help="Meta data file to be used for the dataset.")
    parser.add_argument(
        "--process_audio",
        type=str2bool,
        default=False,
        help="Preprocess audio files.")
    args = parser.parse_args()

    DATA_PATH = args.data_path
    CACHE_PATH = args.cache_path
    CONFIG = load_config(args.config)

    # load the right preprocessor
    preprocessor = importlib.import_module('datasets.preprocess')
    preprocessor = getattr(preprocessor, args.dataset.lower())
    items = preprocessor(args.data_path, args.meta_file)

    print(" > Input path: ", DATA_PATH)
    print(" > Cache path: ", CACHE_PATH)

    ap = AudioProcessor(**CONFIG.audio)
def extract_mel(item):
    """Compute and save the spectrograms of one item, return its meta row.

    Relies on the module-level ``ap``, ``args`` and ``CACHE_PATH`` that the
    script sets up before this function is called.

    Args:
        item: sequence whose first element is the text and whose second
            element is the path of the audio file.

    Returns:
        A list of strings:
        ``[text, wav_path, mel_path, (linear_path,) wav_len, mel_len]``.
        ``linear_path`` is present unless ``args.only_mel`` is set. With
        ``args.process_audio`` the ``wav_path`` column holds the path of the
        saved ``.npy`` audio array instead of the original audio file.

    Raises:
        AssertionError: if the mel and linear spectrograms differ in length
            along axis 1.
    """
    text = item[0]
    file_path = item[1]
    x = ap.load_wav(file_path, ap.sample_rate)
    file_name = os.path.basename(file_path).replace(".wav", "")

    mel_file = file_name + "_mel"
    mel_path = os.path.join(CACHE_PATH, 'mel', mel_file)
    mel = ap.melspectrogram(x.astype('float32')).astype('float32')
    # np.save appends ".npy" to the path, hence the suffix in the meta row.
    np.save(mel_path, mel, allow_pickle=False)
    mel_len = mel.shape[1]
    wav_len = x.shape[0]
    output = [text, file_path, mel_path + ".npy", str(wav_len), str(mel_len)]

    if not args.only_mel:
        linear_file = file_name + "_linear"
        linear_path = os.path.join(CACHE_PATH, 'linear', linear_file)
        linear = ap.spectrogram(x.astype('float32')).astype('float32')
        linear_len = linear.shape[1]
        np.save(linear_path, linear, allow_pickle=False)
        # The linear path goes right after the mel path.
        output.insert(3, linear_path + ".npy")
        assert mel_len == linear_len

    if args.process_audio:
        audio_file = file_name + "_audio"
        audio_path = os.path.join(CACHE_PATH, 'audio', audio_file)
        np.save(audio_path, x, allow_pickle=False)
        # Replace the wav-path column (index 1) with the saved audio array.
        # Deleting index 0 instead would drop the text and shift the wav
        # path into the text column.
        output[1] = audio_path + ".npy"
    return output
if __name__ == "__main__":
    # Driver: create the cache folders, extract the features of every item
    # and write the train/validation meta-files next to them.
    print(" > Number of files: %i" % (len(items)))

    # Sub-folders that extract_mel() writes into.
    cache_dirs = [os.path.join(CACHE_PATH, 'mel')]
    if not args.only_mel:
        cache_dirs.append(os.path.join(CACHE_PATH, 'linear'))
    if args.process_audio:
        cache_dirs.append(os.path.join(CACHE_PATH, 'audio'))

    is_new_cache = not os.path.exists(CACHE_PATH)
    # Create the sub-folders even when the cache root already exists;
    # otherwise np.save() in extract_mel() fails on a missing directory.
    for cache_dir in cache_dirs:
        os.makedirs(cache_dir, exist_ok=True)
    if is_new_cache:
        print(" > A new folder created at {}".format(CACHE_PATH))

    # Extract features
    r = []
    if args.num_proc > 1:
        print(" > Using {} processes.".format(args.num_proc))
        with Pool(args.num_proc) as p:
            r = list(
                tqdm.tqdm(
                    p.imap(extract_mel, items),
                    total=len(items)))
    else:
        print(" > Using single process run.")
        for item in items:
            print(" > ", item[1])
            r.append(extract_mel(item))

    # Save meta data
    if args.cache_path is not None:
        # The first `val_split` rows form the validation set, the rest the
        # training set. The files are written as UTF-8 because the
        # `tts_cache` preprocessor reads them back with that encoding.
        file_path = os.path.join(CACHE_PATH, "tts_metadata_val.csv")
        with open(file_path, "w", encoding="utf8") as meta_file:
            for line in r[:args.val_split]:
                meta_file.write("| ".join(line) + '\n')

        file_path = os.path.join(CACHE_PATH, "tts_metadata.csv")
        with open(file_path, "w", encoding="utf8") as meta_file:
            for line in r[args.val_split:]:
                meta_file.write("| ".join(line) + '\n')

    # copy the used config file to output path for sanity
    copy_config_file(args.config, CACHE_PATH)

View File

@ -36,7 +36,6 @@
"save_step": 200, "save_step": 200,
"data_path": "/home/erogol/Data/LJSpeech-1.1/", "data_path": "/home/erogol/Data/LJSpeech-1.1/",
"data_path_cache": "/media/erogol/data_ssd/Data/Nancy/tts_cache/",
"output_path": "result", "output_path": "result",
"min_seq_len": 0, "min_seq_len": 0,
"max_seq_len": 300, "max_seq_len": 300,

View File

@ -7,7 +7,7 @@ from torch.utils.data import DataLoader
from utils.generic_utils import load_config from utils.generic_utils import load_config
from utils.audio import AudioProcessor from utils.audio import AudioProcessor
from datasets import TTSDataset from datasets import TTSDataset
from datasets.preprocess import ljspeech, tts_cache from datasets.preprocess import ljspeech
file_path = os.path.dirname(os.path.realpath(__file__)) file_path = os.path.dirname(os.path.realpath(__file__))
OUTPATH = os.path.join(file_path, "outputs/loader_tests/") OUTPATH = os.path.join(file_path, "outputs/loader_tests/")
@ -16,15 +16,11 @@ c = load_config(os.path.join(file_path, 'test_config.json'))
ok_ljspeech = os.path.exists(c.data_path) ok_ljspeech = os.path.exists(c.data_path)
DATA_EXIST = True DATA_EXIST = True
CACHE_EXIST = True
if not os.path.exists(c.data_path_cache):
CACHE_EXIST = False
if not os.path.exists(c.data_path): if not os.path.exists(c.data_path):
DATA_EXIST = False DATA_EXIST = False
print(" > Dynamic data loader test: {}".format(DATA_EXIST)) print(" > Dynamic data loader test: {}".format(DATA_EXIST))
print(" > Cache data loader test: {}".format(CACHE_EXIST))
class TestTTSDataset(unittest.TestCase): class TestTTSDataset(unittest.TestCase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -126,8 +122,9 @@ class TestTTSDataset(unittest.TestCase):
wav = self.ap.load_wav(item_idx[0]) wav = self.ap.load_wav(item_idx[0])
mel = self.ap.melspectrogram(wav) mel = self.ap.melspectrogram(wav)
mel_dl = mel_input[0].cpu().numpy() mel_dl = mel_input[0].cpu().numpy()
assert ( assert (abs(mel.T).astype("float32")
abs(mel.T).astype("float32") - abs(mel_dl[:-1])).sum() == 0 - abs(mel_dl[:-1])
).sum() == 0
# check mel-spec correctness # check mel-spec correctness
mel_spec = mel_input[0].cpu().numpy() mel_spec = mel_input[0].cpu().numpy()
@ -139,7 +136,8 @@ class TestTTSDataset(unittest.TestCase):
linear_spec = linear_input[0].cpu().numpy() linear_spec = linear_input[0].cpu().numpy()
wav = self.ap.inv_spectrogram(linear_spec.T) wav = self.ap.inv_spectrogram(linear_spec.T)
self.ap.save_wav(wav, OUTPATH + '/linear_inv_dataloader.wav') self.ap.save_wav(wav, OUTPATH + '/linear_inv_dataloader.wav')
shutil.copy(item_idx[0], OUTPATH + '/linear_target_dataloader.wav') shutil.copy(item_idx[0],
OUTPATH + '/linear_target_dataloader.wav')
# check the last time step to be zero padded # check the last time step to be zero padded
assert linear_input[0, -1].sum() == 0 assert linear_input[0, -1].sum() == 0

View File

@ -53,7 +53,6 @@ def setup_loader(is_val=False, verbose=False):
batch_group_size=0 if is_val else c.batch_group_size * c.batch_size, batch_group_size=0 if is_val else c.batch_group_size * c.batch_size,
min_seq_len=0 if is_val else c.min_seq_len, min_seq_len=0 if is_val else c.min_seq_len,
max_seq_len=float("inf") if is_val else c.max_seq_len, max_seq_len=float("inf") if is_val else c.max_seq_len,
cached=False if c.dataset != "tts_cache" else True,
phoneme_cache_path=c.phoneme_cache_path, phoneme_cache_path=c.phoneme_cache_path,
use_phonemes=c.use_phonemes, use_phonemes=c.use_phonemes,
phoneme_language=c.phoneme_language, phoneme_language=c.phoneme_language,