Batch update after data-loss

This commit is contained in:
Eren Golge 2018-11-02 16:13:51 +01:00
parent f53f9cb360
commit c8a552e627
18 changed files with 1362 additions and 607 deletions

@@ -1,41 +1,47 @@
 {
-    "model_name": "TTS-sigmoid",
-    "model_description": "Net outputting Sigmoid unit",
-    "audio_processor": "audio",
-    "num_mels": 80,
-    "num_freq": 1025,
-    "sample_rate": 22000,
-    "frame_length_ms": 50,
-    "frame_shift_ms": 12.5,
-    "preemphasis": 0.97,
-    "min_level_db": -100,
-    "ref_level_db": 20,
+    "model_name": "TTS-master-_tmp",
+    "model_description": "Higher dropout rate for stopnet and disabled custom initialization, pull current mel prediction to stopnet.",
+    "audio":{
+        "audio_processor": "audio", // audio processor to use, if alternatives are available.
+        "num_mels": 80,             // size of the mel spec frame.
+        "num_freq": 1025,           // number of stft frequency levels. Size of the linear spectrogram frame.
+        "sample_rate": 22050,       // wav sample-rate. If different than the original data, it is resampled.
+        "frame_length_ms": 50,      // stft window length in ms.
+        "frame_shift_ms": 12.5,     // stft window hop-length in ms.
+        "preemphasis": 0.97,        // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis.
+        "min_level_db": -100,       // normalization range
+        "ref_level_db": 20,         // reference level db, theoretically 20db is the sound of air.
+        "power": 1.5,               // value to sharpen wav signals after GL algorithm.
+        "griffin_lim_iters": 60,    // #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
+        "signal_norm": true,        // normalize the spec values in range [0, 1]
+        "symmetric_norm": false,    // move normalization to range [-1, 1]
+        "max_norm": 1,              // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
+        "clip_norm": true,          // clip normalized values into the range.
+        "mel_fmin": null,           // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
+        "mel_fmax": null            // maximum freq level for mel-spec. Tune for dataset!!
+    },
     "embedding_size": 256,
     "text_cleaner": "english_cleaners",
-    "num_loader_workers": 4,
     "epochs": 1000,
-    "lr": 0.002,
+    "lr": 0.0015,
     "warmup_steps": 4000,
-    "lr_decay": 0.5,
-    "decay_step": 100000,
-    "batch_size": 32,
-    "eval_batch_size": -1,
-    "r": 5,
-    "wd": 0.0001,
-    "griffin_lim_iters": 60,
-    "power": 1.5,
+    "batch_size": 32,
+    "eval_batch_size": 32,
+    "r": 1,
+    "wd": 0.000001,
     "checkpoint": true,
-    "save_step": 25000,
+    "save_step": 5000,
     "print_step": 10,
-    "run_eval": false,
-    "data_path": "/snakepit/shared/data/keithito/LJSpeech-1.1/",
-    "meta_file_train": "metadata.csv",
-    "meta_file_val": null,
-    "dataset": "LJSpeech",
+    "run_eval": true,
+    "data_path": "../../Data/LJSpeech-1.1/tts_cache", // can be overwritten from command argument
+    "meta_file_train": "metadata_train.csv",          // metafile for training dataloader
+    "meta_file_val": "metadata_val.csv",              // metafile for validation dataloader
+    "data_loader": "TTSDataset",                      // dataloader, ["TTSDataset", "TTSDatasetCached", "TTSDatasetMemory"]
+    "dataset": "ljspeech",                            // one of TTS.dataset.preprocessors, only valid if dataloader == "TTSDataset"; the rest use "tts_cache" by default.
     "min_seq_len": 0,
-    "output_path": "../keep/"
+    "output_path": "../keep/",
+    "num_loader_workers": 8
 }
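For orientation, the new nested "audio" block is consumed as keyword arguments elsewhere in this commit (extract_features.py and the tests both do this). A minimal sketch, assuming the file above is saved as config.json:

    # Sketch: how this config is consumed in this commit.
    # The file name "config.json" is an illustrative assumption.
    from utils.generic_utils import load_config
    from utils.audio import AudioProcessor

    CONFIG = load_config("config.json")
    ap = AudioProcessor(**CONFIG.audio)  # the whole "audio" block maps to kwargs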

@@ -13,75 +13,72 @@ from utils.data import (prepare_data, pad_per_step, prepare_tensor,
 class MyDataset(Dataset):
     def __init__(self,
-                 root_dir,
-                 csv_file,
+                 root_path,
+                 meta_file,
                  outputs_per_step,
                  text_cleaner,
                  ap,
+                 preprocessor,
                  batch_group_size=0,
                  min_seq_len=0):
-        self.root_dir = root_dir
+        self.root_path = root_path
         self.batch_group_size = batch_group_size
-        self.wav_dir = os.path.join(root_dir, 'wavs')
-        self.csv_dir = os.path.join(root_dir, csv_file)
-        with open(self.csv_dir, "r", encoding="utf8") as f:
-            self.frames = [line.split('|') for line in f]
+        self.items = preprocessor(root_path, meta_file)
         self.outputs_per_step = outputs_per_step
         self.sample_rate = ap.sample_rate
         self.cleaners = text_cleaner
         self.min_seq_len = min_seq_len
         self.ap = ap
-        print(" > Reading LJSpeech from - {}".format(root_dir))
-        print(" | > Number of instances : {}".format(len(self.frames)))
-        self.sort_frames()
+        print(" > Reading LJSpeech from - {}".format(root_path))
+        print(" | > Number of instances : {}".format(len(self.items)))
+        self.sort_items()

     def load_wav(self, filename):
         try:
-            audio = librosa.core.load(filename, sr=self.sample_rate)[0]
+            audio = self.ap.load_wav(filename)
             return audio
         except RuntimeError as e:
             print(" !! Cannot read file : {}".format(filename))

-    def sort_frames(self):
+    def sort_items(self):
         r"""Sort text sequences in ascending order"""
-        lengths = np.array([len(ins[1]) for ins in self.frames])
+        lengths = np.array([len(ins[0]) for ins in self.items])
         print(" | > Max length sequence {}".format(np.max(lengths)))
         print(" | > Min length sequence {}".format(np.min(lengths)))
         print(" | > Avg length sequence {}".format(np.mean(lengths)))

         idxs = np.argsort(lengths)
-        new_frames = []
+        new_items = []
         ignored = []
         for i, idx in enumerate(idxs):
             length = lengths[idx]
             if length < self.min_seq_len:
                 ignored.append(idx)
             else:
-                new_frames.append(self.frames[idx])
+                new_items.append(self.items[idx])
         print(" | > {} instances are ignored by min_seq_len ({})".format(
             len(ignored), self.min_seq_len))
         # shuffle batch groups
         if self.batch_group_size > 0:
             print(" | > Batch group shuffling is active.")
-            for i in range(len(new_frames) // self.batch_group_size):
+            for i in range(len(new_items) // self.batch_group_size):
                 offset = i * self.batch_group_size
                 end_offset = offset + self.batch_group_size
-                temp_frames = new_frames[offset : end_offset]
-                random.shuffle(temp_frames)
-                new_frames[offset : end_offset] = temp_frames
-        self.frames = new_frames
+                temp_items = new_items[offset : end_offset]
+                random.shuffle(temp_items)
+                new_items[offset : end_offset] = temp_items
+        self.items = new_items

     def __len__(self):
-        return len(self.frames)
+        return len(self.items)

     def __getitem__(self, idx):
-        wav_name = os.path.join(self.wav_dir, self.frames[idx][0]) + '.wav'
-        text = self.frames[idx][1]
+        text, wav_file = self.items[idx]
         text = np.asarray(
             text_to_sequence(text, [self.cleaners]), dtype=np.int32)
-        wav = np.asarray(self.load_wav(wav_name), dtype=np.float32)
-        sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
+        wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
+        sample = {'text': text, 'wav': wav, 'item_idx': self.items[idx][1]}
         return sample

     def collate_fn(self, batch):
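As a usage sketch, the refactored loader is now constructed with an injected preprocessor function instead of a hard-coded csv/wav layout. Names are taken from the tests later in this commit; c is a loaded config and ap an AudioProcessor, both assumptions here:

    # Sketch: constructing the refactored dataset with an injected preprocessor.
    from TTS.datasets import TTSDataset
    from TTS.datasets.preprocess import ljspeech

    dataset = TTSDataset.MyDataset(
        c.data_path,            # dataset root
        'metadata.csv',         # meta file parsed by the preprocessor
        c.r,
        c.text_cleaner,
        preprocessor=ljspeech,  # returns [text, wav_path] items
        ap=ap,
        batch_group_size=0,
        min_seq_len=c.min_seq_len)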

@@ -1,4 +1,5 @@
 import os
+import random
 import numpy as np
 import collections
 import librosa
@@ -6,32 +7,34 @@ import torch
 from torch.utils.data import Dataset
 from utils.text import text_to_sequence
+from datasets.preprocess import tts_cache
 from utils.data import (prepare_data, pad_per_step, prepare_tensor,
                         prepare_stop_target)

 class MyDataset(Dataset):
-    # TODO: Not finished yet.
     def __init__(self,
-                 root_dir,
-                 csv_file,
+                 root_path,
+                 meta_file,
                  outputs_per_step,
                  text_cleaner,
                  ap,
-                 min_seq_len=0):
-        self.root_dir = root_dir
-        self.wav_dir = os.path.join(root_dir, 'wavs')
-        self.feat_dir = os.path.join(root_dir, 'loader_data')
-        self.csv_dir = os.path.join(root_dir, csv_file)
-        with open(self.csv_dir, "r", encoding="utf8") as f:
-            self.frames = [line.split('|') for line in f]
+                 batch_group_size=0,
+                 min_seq_len=0,
+                 **kwargs
+                 ):
+        self.root_path = root_path
+        self.batch_group_size = batch_group_size
+        self.feat_dir = os.path.join(root_path, 'loader_data')
+        self.items = tts_cache(root_path, meta_file)
         self.outputs_per_step = outputs_per_step
         self.sample_rate = ap.sample_rate
         self.cleaners = text_cleaner
         self.min_seq_len = min_seq_len
-        self.items = [None] * len(self.frames)
-        print(" > Reading LJSpeech from - {}".format(root_dir))
-        print(" | > Number of instances : {}".format(len(self.frames)))
-        self._sort_frames()
+        print(" > Reading LJSpeech from - {}".format(root_path))
+        print(" | > Number of instances : {}".format(len(self.items)))
+        self.sort_items()

     def load_wav(self, filename):
         try:
@@ -44,9 +47,9 @@ class MyDataset(Dataset):
         data = np.load(filename).astype('float32')
         return data

-    def _sort_frames(self):
-        r"""Sort sequences in ascending order"""
-        lengths = np.array([len(ins[1]) for ins in self.frames])
+    def sort_items(self):
+        r"""Sort text sequences in ascending order"""
+        lengths = np.array([len(ins[-1]) for ins in self.items])
         print(" | > Max length sequence {}".format(np.max(lengths)))
         print(" | > Min length sequence {}".format(np.min(lengths)))
@@ -60,37 +63,43 @@ class MyDataset(Dataset):
             if length < self.min_seq_len:
                 ignored.append(idx)
             else:
-                new_frames.append(self.frames[idx])
+                new_frames.append(self.items[idx])
         print(" | > {} instances are ignored by min_seq_len ({})".format(
             len(ignored), self.min_seq_len))
-        self.frames = new_frames
+        # shuffle batch groups
+        if self.batch_group_size > 0:
+            print(" | > Batch group shuffling is active.")
+            for i in range(len(new_frames) // self.batch_group_size):
+                offset = i * self.batch_group_size
+                end_offset = offset + self.batch_group_size
+                temp_frames = new_frames[offset : end_offset]
+                random.shuffle(temp_frames)
+                new_frames[offset : end_offset] = temp_frames
+        self.items = new_frames

     def __len__(self):
-        return len(self.frames)
+        return len(self.items)

     def __getitem__(self, idx):
-        if self.items[idx] is None:
-            wav_name = os.path.join(self.wav_dir, self.frames[idx][0]) + '.wav'
-            mel_name = os.path.join(self.feat_dir,
-                                    self.frames[idx][0]) + '.mel.npy'
-            linear_name = os.path.join(self.feat_dir,
-                                       self.frames[idx][0]) + '.linear.npy'
-            text = self.frames[idx][1]
-            text = np.asarray(
-                text_to_sequence(text, [self.cleaners]), dtype=np.int32)
-            wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
-            mel = self.load_np(mel_name)
-            linear = self.load_np(linear_name)
-            sample = {
-                'text': text,
-                'wav': wav,
-                'item_idx': self.frames[idx][0],
-                'mel': mel,
-                'linear': linear
-            }
-            self.items[idx] = sample
-        else:
-            sample = self.items[idx]
+        wav_name = self.items[idx][0]
+        mel_name = self.items[idx][1]
+        linear_name = self.items[idx][2]
+        text = self.items[idx][-1]
+        text = np.asarray(
+            text_to_sequence(text, [self.cleaners]), dtype=np.int32)
+        if wav_name.split('.')[-1] == 'npy':
+            wav = self.load_np(wav_name)
+        else:
+            wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
+        mel = self.load_np(mel_name)
+        linear = self.load_np(linear_name)
+        sample = {
+            'text': text,
+            'wav': wav,
+            'item_idx': self.items[idx][0],
+            'mel': mel,
+            'linear': linear
+        }
         return sample

     def collate_fn(self, batch):
@@ -132,7 +141,6 @@ class MyDataset(Dataset):
         # PAD features with largest length + a zero frame
         linear = prepare_tensor(linear, self.outputs_per_step)
         mel = prepare_tensor(mel, self.outputs_per_step)
-        assert mel.shape[2] == linear.shape[2]
         timesteps = mel.shape[2]

         # B x T x D
@@ -147,8 +155,7 @@ class MyDataset(Dataset):
             mel_lengths = torch.LongTensor(mel_lengths)
             stop_targets = torch.FloatTensor(stop_targets)

-            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
-                0]
+            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs

         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
                          found {}".format(type(batch[0]))))

@@ -1,72 +1,81 @@
 import os
-import glob
 import random
 import numpy as np
 import collections
 import librosa
 import torch
+from tqdm import tqdm
 from torch.utils.data import Dataset
 from utils.text import text_to_sequence
+from datasets.preprocess import tts_cache
 from utils.data import (prepare_data, pad_per_step, prepare_tensor,
                         prepare_stop_target)

 class MyDataset(Dataset):
+    # TODO: Not finished yet.
     def __init__(self,
-                 root_dir,
-                 csv_file,
+                 root_path,
+                 meta_file,
                  outputs_per_step,
                  text_cleaner,
                  ap,
-                 min_seq_len=0):
-        self.root_dir = root_dir
-        self.wav_dir = os.path.join(root_dir, 'wav')
-        self.wav_files = glob.glob(os.path.join(self.wav_dir, '*.wav'))
-        self._create_file_dict()
-        self.csv_dir = os.path.join(root_dir, csv_file)
-        with open(self.csv_dir, "r", encoding="utf8") as f:
-            self.frames = [
-                line.split('\t') for line in f
-                if line.split('\t')[0] in self.wav_files_dict.keys()
-            ]
+                 batch_group_size=0,
+                 min_seq_len=0,
+                 **kwargs
+                 ):
+        self.root_path = root_path
+        self.batch_group_size = batch_group_size
+        self.feat_dir = os.path.join(root_path, 'loader_data')
+        self.items = tts_cache(root_path, meta_file)
         self.outputs_per_step = outputs_per_step
         self.sample_rate = ap.sample_rate
         self.cleaners = text_cleaner
         self.min_seq_len = min_seq_len
-        self.ap = ap
-        print(" > Reading Kusal from - {}".format(root_dir))
-        print(" | > Number of instances : {}".format(len(self.frames)))
-        self._sort_frames()
+        self.wavs = None
+        self.mels = None
+        self.linears = None
+        print(" > Reading LJSpeech from - {}".format(root_path))
+        print(" | > Number of instances : {}".format(len(self.items)))
+        self.sort_items()
+        self.fill_data()
+
+    def fill_data(self):
+        if self.wavs is None and self.mels is None:
+            self.wavs = []
+            self.mels = []
+            self.linears = []
+            self.texts = []
+            for item in tqdm(self.items):
+                wav_file = item[0]
+                mel_file = item[1]
+                linear_file = item[2]
+                text = item[-1]
+                wav = self.load_np(wav_file)
+                mel = self.load_np(mel_file)
+                linear = self.load_np(linear_file)
+                self.wavs.append(wav)
+                self.mels.append(mel)
+                self.linears.append(linear)
+                self.texts.append(np.asarray(
+                    text_to_sequence(text, [self.cleaners]), dtype=np.int32))
+            print(" > Data loaded to memory")

     def load_wav(self, filename):
-        """ Load audio and trim silence """
         try:
-            audio = librosa.core.load(filename, sr=self.sample_rate)[0]
-            margin = int(self.sample_rate * 0.1)
-            audio = audio[margin:-margin]
-            return self._trim_silence(audio)
+            audio = librosa.core.load(filename, sr=self.sample_rate)
+            return audio
         except RuntimeError as e:
             print(" !! Cannot read file : {}".format(filename))

-    def _trim_silence(self, wav):
-        return librosa.effects.trim(
-            wav, top_db=40, frame_length=1024, hop_length=256)[0]
-
-    def _create_file_dict(self):
-        self.wav_files_dict = {}
-        for fn in self.wav_files:
-            parts = fn.split('-')
-            key = parts[1]
-            value = fn
-            try:
-                self.wav_files_dict[key].append(value)
-            except:
-                self.wav_files_dict[key] = [value]
-
-    def _sort_frames(self):
-        r"""Sort sequences in ascending order"""
-        lengths = np.array([len(ins[2]) for ins in self.frames])
+    def load_np(self, filename):
+        data = np.load(filename).astype('float32')
+        return data
+
+    def sort_items(self):
+        r"""Sort text sequences in ascending order"""
+        lengths = np.array([len(ins[-1]) for ins in self.items])
         print(" | > Max length sequence {}".format(np.max(lengths)))
         print(" | > Min length sequence {}".format(np.min(lengths)))
@@ -80,24 +89,35 @@ class MyDataset(Dataset):
             if length < self.min_seq_len:
                 ignored.append(idx)
             else:
-                new_frames.append(self.frames[idx])
+                new_frames.append(self.items[idx])
         print(" | > {} instances are ignored by min_seq_len ({})".format(
             len(ignored), self.min_seq_len))
-        self.frames = new_frames
+        # shuffle batch groups
+        if self.batch_group_size > 0:
+            print(" | > Batch group shuffling is active.")
+            for i in range(len(new_frames) // self.batch_group_size):
+                offset = i * self.batch_group_size
+                end_offset = offset + self.batch_group_size
+                temp_frames = new_frames[offset : end_offset]
+                random.shuffle(temp_frames)
+                new_frames[offset : end_offset] = temp_frames
+        self.items = new_frames

     def __len__(self):
-        return len(self.frames)
+        return len(self.items)

     def __getitem__(self, idx):
-        sidx = self.frames[idx][0]
-        sidx_files = self.wav_files_dict[sidx]
-        file_name = random.choice(sidx_files)
-        wav_name = os.path.join(self.wav_dir, file_name)
-        text = self.frames[idx][2]
-        text = np.asarray(
-            text_to_sequence(text, [self.cleaners]), dtype=np.int32)
-        wav = np.asarray(self.load_wav(wav_name), dtype=np.float32)
-        sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
+        text = self.texts[idx]
+        wav = self.wavs[idx]
+        mel = self.mels[idx]
+        linear = self.linears[idx]
+        sample = {
+            'text': text,
+            'wav': wav,
+            'item_idx': self.items[idx][0],
+            'mel': mel,
+            'linear': linear
+        }
         return sample

     def collate_fn(self, batch):
@@ -116,12 +136,11 @@ class MyDataset(Dataset):
             wav = [d['wav'] for d in batch]
             item_idxs = [d['item_idx'] for d in batch]
             text = [d['text'] for d in batch]
+            mel = [d['mel'] for d in batch]
+            linear = [d['linear'] for d in batch]

             text_lenghts = np.array([len(x) for x in text])
             max_text_len = np.max(text_lenghts)

-            linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
-            mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
             mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame

             # compute 'stop token' targets
@@ -140,7 +159,6 @@ class MyDataset(Dataset):
         # PAD features with largest length + a zero frame
         linear = prepare_tensor(linear, self.outputs_per_step)
         mel = prepare_tensor(mel, self.outputs_per_step)
-        assert mel.shape[2] == linear.shape[2]
         timesteps = mel.shape[2]

         # B x T x D
@@ -155,8 +173,7 @@ class MyDataset(Dataset):
             mel_lengths = torch.LongTensor(mel_lengths)
             stop_targets = torch.FloatTensor(stop_targets)

-            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
-                0]
+            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs

         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
                          found {}".format(type(batch[0]))))

datasets/preprocess.py (new file)

@@ -0,0 +1,60 @@
import os
import random


def tts_cache(root_path, meta_file):
    """This format is set for the meta-file generated by extract_features.py"""
    txt_file = os.path.join(root_path, meta_file)
    items = []
    with open(txt_file, 'r', encoding='utf8') as f:
        for line in f:
            cols = line.split('| ')
            items.append(cols)  # wav_full_path, mel_name, linear_name, wav_len, mel_len, text
    random.shuffle(items)
    return items


# def tweb(root_path, meta_file):
#     # TODO
#     pass
#     return


# def kusal(root_path, meta_file):
#     txt_file = os.path.join(root_path, meta_file)
#     texts = []
#     wavs = []
#     with open(txt_file, "r", encoding="utf8") as f:
#         frames = [
#             line.split('\t') for line in f
#             if line.split('\t')[0] in self.wav_files_dict.keys()
#         ]
#     # TODO: code the rest
#     return {'text': texts, 'wavs': wavs}


def ljspeech(root_path, meta_file):
    """Normalizes the LJSpeech meta data file to TTS format"""
    txt_file = os.path.join(root_path, meta_file)
    items = []
    with open(txt_file, 'r') as ttf:
        for line in ttf:
            cols = line.split('|')
            wav_file = os.path.join(root_path, 'wavs', cols[0] + '.wav')
            text = cols[1]
            items.append([text, wav_file])
    random.shuffle(items)
    return items


def nancy(root_path, meta_file):
    """Normalizes the Nancy meta data file to TTS format"""
    txt_file = os.path.join(root_path, meta_file)
    items = []
    with open(txt_file, 'r') as ttf:
        for line in ttf:
            id = line.split()[1]
            text = line[line.find('"') + 1:line.rfind('"') - 1]
            wav_file = root_path + 'wavn/' + id + '.wav'
            items.append([text, wav_file])
    random.shuffle(items)
    return items

extract_features.py (new file)

@@ -0,0 +1,138 @@
'''
Extract spectrograms and save them to file for training
'''
import os
import sys
import time
import glob
import argparse
import librosa
import importlib
import numpy as np
import tqdm
from utils.generic_utils import load_config, copy_config_file
from utils.audio import AudioProcessor
from multiprocessing import Pool

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, help='Data folder.')
    parser.add_argument('--cache_path', type=str, help='Cache folder, place to output all the intermediate spectrogram files.')
    # parser.add_argument('--keep_cache', type=bool, help='If True, it keeps the cache folder.')
    # parser.add_argument('--hdf5_path', type=str, help='hdf5 folder.')
    parser.add_argument(
        '--config', type=str, help='conf.json file for run settings.')
    parser.add_argument(
        "--num_proc", type=int, default=8, help="number of processes.")
    parser.add_argument(
        "--trim_silence",
        type=bool,
        default=False,
        help="trim silence in the voice clip.")
    parser.add_argument("--only_mel", type=bool, default=False, help="If True, only melspectrogram is extracted.")
    parser.add_argument("--dataset", type=str, help="Target dataset to be processed.")
    parser.add_argument("--val_split", type=int, default=0, help="Number of instances for validation.")
    parser.add_argument("--meta_file", type=str, help="Meta data file to be used for the dataset.")
    parser.add_argument("--process_audio", type=bool, default=False, help="Preprocess audio files.")

    args = parser.parse_args()

    DATA_PATH = args.data_path
    CACHE_PATH = args.cache_path
    CONFIG = load_config(args.config)

    # load the right preprocessor
    preprocessor = importlib.import_module('datasets.preprocess')
    preprocessor = getattr(preprocessor, args.dataset.lower())
    items = preprocessor(args.data_path, args.meta_file)

    print(" > Input path: ", DATA_PATH)
    print(" > Cache path: ", CACHE_PATH)

    # audio = importlib.import_module('utils.' + c.audio_processor)
    # AudioProcessor = getattr(audio, 'AudioProcessor')
    ap = AudioProcessor(**CONFIG.audio)


def trim_silence(wav):
    """ Trim silent parts with a threshold and 0.1 sec margin """
    margin = int(ap.sample_rate * 0.1)
    wav = wav[margin:-margin]
    return librosa.effects.trim(
        wav, top_db=40, frame_length=1024, hop_length=256)[0]


def extract_mel(item):
    """ Compute spectrograms, length information """
    text = item[0]
    file_path = item[1]
    x = ap.load_wav(file_path, ap.sample_rate)
    if args.trim_silence:
        x = trim_silence(x)
    file_name = os.path.basename(file_path).replace(".wav", "")
    mel_file = file_name + "_mel"
    mel_path = os.path.join(CACHE_PATH, 'mel', mel_file)
    mel = ap.melspectrogram(x.astype('float32')).astype('float32')
    np.save(mel_path, mel, allow_pickle=False)
    mel_len = mel.shape[1]
    wav_len = x.shape[0]
    output = [file_path, mel_path + ".npy", str(wav_len), str(mel_len), text]
    if not args.only_mel:
        linear_file = file_name + "_linear"
        linear_path = os.path.join(CACHE_PATH, 'linear', linear_file)
        linear = ap.spectrogram(x.astype('float32')).astype('float32')
        linear_len = linear.shape[1]
        np.save(linear_path, linear, allow_pickle=False)
        output.insert(2, linear_path + ".npy")
        assert mel_len == linear_len
    if args.process_audio:
        audio_file = file_name + "_audio"
        audio_path = os.path.join(CACHE_PATH, 'audio', audio_file)
        np.save(audio_path, x, allow_pickle=False)
        del output[0]
        output.insert(0, audio_path + ".npy")
    return output


if __name__ == "__main__":
    print(" > Number of files: %i" % (len(items)))
    if not os.path.exists(CACHE_PATH):
        os.makedirs(os.path.join(CACHE_PATH, 'mel'))
        if not args.only_mel:
            os.makedirs(os.path.join(CACHE_PATH, 'linear'))
        if args.process_audio:
            os.makedirs(os.path.join(CACHE_PATH, 'audio'))
        print(" > A new folder created at {}".format(CACHE_PATH))

    # Extract features
    r = []
    if args.num_proc > 1:
        print(" > Using {} processes.".format(args.num_proc))
        with Pool(args.num_proc) as p:
            r = list(
                tqdm.tqdm(
                    p.imap(extract_mel, items),
                    total=len(items)))
            # r = list(p.imap(extract_mel, file_names))
    else:
        print(" > Using single process run.")
        for item in items:
            print(" > ", item[1])
            r.append(extract_mel(item))

    # Save meta data
    if args.cache_path is not None:
        file_path = os.path.join(CACHE_PATH, "tts_metadata_val.csv")
        file = open(file_path, "w")
        for line in r[:args.val_split]:
            line = "| ".join(line)
            file.write(line + '\n')
        file.close()

        file_path = os.path.join(CACHE_PATH, "tts_metadata.csv")
        file = open(file_path, "w")
        for line in r[args.val_split:]:
            line = "| ".join(line)
            file.write(line + '\n')
        file.close()

    # copy the used config file to output path for sanity
    copy_config_file(args.config, CACHE_PATH)
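A plausible invocation of the script above, using the flags its argparse block defines (all paths and values here are illustrative assumptions):

    python extract_features.py --data_path ../../Data/LJSpeech-1.1/ \
        --cache_path ../../Data/LJSpeech-1.1/tts_cache --config config.json \
        --dataset ljspeech --meta_file metadata.csv --num_proc 8 --val_split 0

This writes mel (and optionally linear/audio) .npy files under the cache path and emits tts_metadata.csv / tts_metadata_val.csv, which the "tts_cache" preprocessor reads back.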

@@ -52,10 +52,25 @@ class LocationSensitiveAttention(nn.Module):
                 stride=1,
                 padding=0,
                 bias=False))
-        self.loc_linear = nn.Linear(filters, attn_dim)
+        self.loc_linear = nn.Linear(filters, attn_dim, bias=True)
         self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
         self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
         self.v = nn.Linear(attn_dim, 1, bias=False)
+        # self.init_layers()
+
+    def init_layers(self):
+        torch.nn.init.xavier_uniform_(
+            self.loc_linear.weight,
+            gain=torch.nn.init.calculate_gain('tanh'))
+        torch.nn.init.xavier_uniform_(
+            self.query_layer.weight,
+            gain=torch.nn.init.calculate_gain('tanh'))
+        torch.nn.init.xavier_uniform_(
+            self.annot_layer.weight,
+            gain=torch.nn.init.calculate_gain('tanh'))
+        torch.nn.init.xavier_uniform_(
+            self.v.weight,
+            gain=torch.nn.init.calculate_gain('linear'))

     def forward(self, annot, query, loc):
         """

@@ -23,6 +23,13 @@ class Prenet(nn.Module):
         ])
         self.relu = nn.ReLU()
         self.dropout = nn.Dropout(0.5)
+        # self.init_layers()
+
+    def init_layers(self):
+        for layer in self.layers:
+            torch.nn.init.xavier_uniform_(
+                layer.weight,
+                gain=torch.nn.init.calculate_gain('relu'))

     def forward(self, inputs):
         for linear in self.layers:
@@ -55,6 +62,7 @@ class BatchNormConv1d(nn.Module):
                 stride,
                 padding,
                 activation=None):
+
         super(BatchNormConv1d, self).__init__()
         self.padding = padding
         self.padder = nn.ConstantPad1d(padding, 0)
@@ -68,13 +76,28 @@ class BatchNormConv1d(nn.Module):
         # Following tensorflow's default parameters
         self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3)
         self.activation = activation
+        # self.init_layers()
+
+    def init_layers(self):
+        if type(self.activation) == torch.nn.ReLU:
+            w_gain = 'relu'
+        elif type(self.activation) == torch.nn.Tanh:
+            w_gain = 'tanh'
+        elif self.activation is None:
+            w_gain = 'linear'
+        else:
+            raise RuntimeError('Unknown activation function')
+        torch.nn.init.xavier_uniform_(
+            self.conv1d.weight,
+            gain=torch.nn.init.calculate_gain(w_gain))

     def forward(self, x):
         x = self.padder(x)
         x = self.conv1d(x)
+        x = self.bn(x)
         if self.activation is not None:
             x = self.activation(x)
-        return self.bn(x)
+        return x

 class Highway(nn.Module):
@@ -86,6 +109,15 @@ class Highway(nn.Module):
         self.T.bias.data.fill_(-1)
         self.relu = nn.ReLU()
         self.sigmoid = nn.Sigmoid()
+        # self.init_layers()
+
+    def init_layers(self):
+        torch.nn.init.xavier_uniform_(
+            self.H.weight,
+            gain=torch.nn.init.calculate_gain('relu'))
+        torch.nn.init.xavier_uniform_(
+            self.T.weight,
+            gain=torch.nn.init.calculate_gain('sigmoid'))

     def forward(self, inputs):
         H = self.relu(self.H(inputs))
@@ -276,7 +308,7 @@ class Decoder(nn.Module):
         super(Decoder, self).__init__()
         self.r = r
         self.in_features = in_features
-        self.max_decoder_steps = 200
+        self.max_decoder_steps = 500
         self.memory_dim = memory_dim
         # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
@@ -294,7 +326,16 @@ class Decoder(nn.Module):
             [nn.GRUCell(256, 256) for _ in range(2)])
         # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
-        self.stopnet = StopNet(r, memory_dim)
+        self.stopnet = StopNet(256 + memory_dim * r)
+        # self.init_layers()
+
+    def init_layers(self):
+        torch.nn.init.xavier_uniform_(
+            self.project_to_decoder_in.weight,
+            gain=torch.nn.init.calculate_gain('linear'))
+        torch.nn.init.xavier_uniform_(
+            self.proj_to_mel.weight,
+            gain=torch.nn.init.calculate_gain('linear'))

     def forward(self, inputs, memory=None, mask=None):
         """
@@ -350,7 +391,7 @@ class Decoder(nn.Module):
         memory_input = initial_memory
         while True:
             if t > 0:
-                if greedy:
+                if memory is None:
                     memory_input = outputs[-1]
                 else:
                     memory_input = memory[t - 1]
@@ -375,17 +416,14 @@ class Decoder(nn.Module):
             decoder_output = decoder_input
             # predict mel vectors from decoder vectors
             output = self.proj_to_mel(decoder_output)
-            output = torch.sigmoid(output)
-            stop_input = output
             # predict stop token
-            stop_token, stopnet_rnn_hidden = self.stopnet(
-                stop_input, stopnet_rnn_hidden)
+            stopnet_input = torch.cat([decoder_input, output], -1)
+            stop_token = self.stopnet(stopnet_input)
             outputs += [output]
             attentions += [attention]
             stop_tokens += [stop_token]
             t += 1
-            if (not greedy and self.training) or (greedy
-                                                  and memory is not None):
+            if memory is not None:
                 if t >= T_decoder:
                     break
             else:
@@ -394,7 +432,6 @@ class Decoder(nn.Module):
             elif t > self.max_decoder_steps:
                 print(" | > Decoder stopped with 'max_decoder_steps")
                 break
-        assert greedy or len(outputs) == T_decoder
         # Back to batch first
         attentions = torch.stack(attentions).transpose(0, 1)
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
@@ -407,32 +444,22 @@ class StopNet(nn.Module):
     Predicting stop-token in decoder.

     Args:
-        r (int): number of output frames of the network.
-        memory_dim (int): feature dimension for each output frame.
+        in_features (int): feature dimension of input.
     """

-    def __init__(self, r, memory_dim):
-        r"""
-        Predicts the stop token to stop the decoder at testing time
-
-        Args:
-            r (int): number of network output frames.
-            memory_dim (int): single feature dim of a single network output frame.
-        """
+    def __init__(self, in_features):
         super(StopNet, self).__init__()
-        self.rnn = nn.GRUCell(memory_dim * r, memory_dim * r)
-        self.relu = nn.ReLU()
-        self.linear = nn.Linear(r * memory_dim, 1)
+        self.dropout = nn.Dropout(0.5)
+        self.linear = nn.Linear(in_features, 1)
         self.sigmoid = nn.Sigmoid()
+        torch.nn.init.xavier_uniform_(
+            self.linear.weight,
+            gain=torch.nn.init.calculate_gain('linear'))

-    def forward(self, inputs, rnn_hidden):
-        """
-        Args:
-            inputs: network output tensor with r x memory_dim feature dimension.
-            rnn_hidden: hidden state of the RNN cell.
-        """
-        rnn_hidden = self.rnn(inputs, rnn_hidden)
-        outputs = self.relu(rnn_hidden)
+    def forward(self, inputs):
+        # rnn_hidden = self.rnn(inputs, rnn_hidden)
+        # outputs = self.relu(rnn_hidden)
+        outputs = self.dropout(inputs)
         outputs = self.linear(outputs)
         outputs = self.sigmoid(outputs)
-        return outputs, rnn_hidden
+        return outputs
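For orientation, this is the "pull current mel prediction to stopnet" change from the commit description: the reworked StopNet is a stateless dropout + linear + sigmoid head over the concatenated decoder state and current mel prediction. A minimal standalone sketch (dimensions are illustrative; r=1 and memory_dim=80 mirror the new config):

    import torch
    import torch.nn as nn

    # Sketch of the reworked stop-token head (dims are assumptions).
    r, memory_dim = 1, 80
    stopnet = nn.Sequential(
        nn.Dropout(0.5), nn.Linear(256 + memory_dim * r, 1), nn.Sigmoid())

    decoder_input = torch.randn(4, 256)       # decoder RNN output, B x 256
    output = torch.randn(4, memory_dim * r)   # current mel frame prediction
    stopnet_input = torch.cat([decoder_input, output], -1)
    stop_token = stopnet(stopnet_input)       # B x 1, probability of stopping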

@@ -80,11 +80,11 @@ setup(
         "matplotlib==2.0.2",
         "Pillow",
         "flask",
-        "lws",
+        # "lws",
+        "tqdm",
     ],
     extras_require={
         "bin": [
-            "tqdm",
             "requests",
         ],
     })

tests/audio_tests.py (new file)

@@ -0,0 +1,142 @@
import os
import unittest
import numpy as np
import torch as T

from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import load_config

file_path = os.path.dirname(os.path.realpath(__file__))
INPUTPATH = os.path.join(file_path, 'inputs')
OUTPATH = os.path.join(file_path, "outputs/audio_tests")
os.makedirs(OUTPATH, exist_ok=True)

c = load_config(os.path.join(file_path, 'test_config.json'))


class TestAudio(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super(TestAudio, self).__init__(*args, **kwargs)
        self.ap = AudioProcessor(**c.audio)

    def test_audio_synthesis(self):
        """ 1. load wav
            2. set normalization parameters
            3. extract mel-spec
            4. invert to wav and save the output
        """
        print(" > Sanity check for the process wav -> mel -> wav")

        def _test(max_norm, signal_norm, symmetric_norm, clip_norm):
            self.ap.max_norm = max_norm
            self.ap.signal_norm = signal_norm
            self.ap.symmetric_norm = symmetric_norm
            self.ap.clip_norm = clip_norm
            wav = self.ap.load_wav(INPUTPATH + "/example_1.wav")
            mel = self.ap.melspectrogram(wav)
            wav_ = self.ap.inv_mel_spectrogram(mel)
            file_name = "/audio_test-melspec_max_norm_{}-signal_norm_{}-symmetric_{}-clip_norm_{}.wav"\
                .format(max_norm, signal_norm, symmetric_norm, clip_norm)
            print(" | > Creating wav file at : ", file_name)
            self.ap.save_wav(wav_, OUTPATH + file_name)

        # maxnorm = 1.0
        _test(1., False, False, False)
        _test(1., True, False, False)
        _test(1., True, True, False)
        _test(1., True, False, True)
        _test(1., True, True, True)
        # maxnorm = 4.0
        _test(4., False, False, False)
        _test(4., True, False, False)
        _test(4., True, True, False)
        _test(4., True, False, True)
        _test(4., True, True, True)

    def test_normalize(self):
        """Check normalization and denormalization for range values and consistency """
        print(" > Testing normalization and denormalization.")
        wav = self.ap.load_wav(INPUTPATH + "/example_1.wav")
        self.ap.signal_norm = False
        x = self.ap.melspectrogram(wav)
        x_old = x

        self.ap.signal_norm = True
        self.ap.symmetric_norm = False
        self.ap.clip_norm = False
        self.ap.max_norm = 4.0
        x_norm = self.ap._normalize(x)
        print(x_norm.max(), " -- ", x_norm.min())
        assert (x_old - x).sum() == 0
        # check value range
        assert x_norm.max() <= self.ap.max_norm + 1, x_norm.max()
        assert x_norm.min() >= 0 - 1, x_norm.min()
        # check denorm.
        x_ = self.ap._denormalize(x_norm)
        assert (x - x_).sum() < 1e-3, (x - x_).mean()

        self.ap.signal_norm = True
        self.ap.symmetric_norm = False
        self.ap.clip_norm = True
        self.ap.max_norm = 4.0
        x_norm = self.ap._normalize(x)
        print(x_norm.max(), " -- ", x_norm.min())
        assert (x_old - x).sum() == 0
        # check value range
        assert x_norm.max() <= self.ap.max_norm, x_norm.max()
        assert x_norm.min() >= 0, x_norm.min()
        # check denorm.
        x_ = self.ap._denormalize(x_norm)
        assert (x - x_).sum() < 1e-3, (x - x_).mean()

        self.ap.signal_norm = True
        self.ap.symmetric_norm = True
        self.ap.clip_norm = False
        self.ap.max_norm = 4.0
        x_norm = self.ap._normalize(x)
        print(x_norm.max(), " -- ", x_norm.min())
        assert (x_old - x).sum() == 0
        # check value range
        assert x_norm.max() <= self.ap.max_norm + 1, x_norm.max()
        assert x_norm.min() >= -self.ap.max_norm - 2, x_norm.min()
        assert x_norm.min() <= 0, x_norm.min()
        # check denorm.
        x_ = self.ap._denormalize(x_norm)
        assert (x - x_).sum() < 1e-3, (x - x_).mean()

        self.ap.signal_norm = True
        self.ap.symmetric_norm = True
        self.ap.clip_norm = True
        self.ap.max_norm = 4.0
        x_norm = self.ap._normalize(x)
        print(x_norm.max(), " -- ", x_norm.min())
        assert (x_old - x).sum() == 0
        # check value range
        assert x_norm.max() <= self.ap.max_norm, x_norm.max()
        assert x_norm.min() >= -self.ap.max_norm, x_norm.min()
        assert x_norm.min() <= 0, x_norm.min()
        # check denorm.
        x_ = self.ap._denormalize(x_norm)
        assert (x - x_).sum() < 1e-3, (x - x_).mean()

        self.ap.signal_norm = True
        self.ap.symmetric_norm = False
        self.ap.max_norm = 1.0
        x_norm = self.ap._normalize(x)
        print(x_norm.max(), " -- ", x_norm.min())
        assert (x_old - x).sum() == 0
        assert x_norm.max() <= self.ap.max_norm, x_norm.max()
        assert x_norm.min() >= 0, x_norm.min()
        x_ = self.ap._denormalize(x_norm)
        assert (x - x_).sum() < 1e-3

        self.ap.signal_norm = True
        self.ap.symmetric_norm = True
        self.ap.max_norm = 1.0
        x_norm = self.ap._normalize(x)
        print(x_norm.max(), " -- ", x_norm.min())
        assert (x_old - x).sum() == 0
        assert x_norm.max() <= self.ap.max_norm, x_norm.max()
        assert x_norm.min() >= -self.ap.max_norm, x_norm.min()
        assert x_norm.min() < 0, x_norm.min()
        x_ = self.ap._denormalize(x_norm)
        assert (x - x_).sum() < 1e-3
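These tests should run under the standard unittest runner; the exact command depends on how the TTS package is importable from the repo root, so the following is an assumption:

    python -m unittest tests.audio_tests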

@@ -1,132 +1,49 @@
 import os
 import unittest
+import shutil
 import numpy as np

 from torch.utils.data import DataLoader
 from TTS.utils.generic_utils import load_config
 from TTS.utils.audio import AudioProcessor
-from TTS.datasets import LJSpeech, Kusal
+from TTS.datasets import TTSDataset, TTSDatasetCached, TTSDatasetMemory
+from TTS.datasets.preprocess import ljspeech, tts_cache

 file_path = os.path.dirname(os.path.realpath(__file__))
+OUTPATH = os.path.join(file_path, "outputs/loader_tests/")
+os.makedirs(OUTPATH, exist_ok=True)
 c = load_config(os.path.join(file_path, 'test_config.json'))
-ok_kusal = os.path.exists(c.data_path_Kusal)
-ok_ljspeech = os.path.exists(c.data_path_LJSpeech)
+ok_ljspeech = os.path.exists(c.data_path)


-class TestLJSpeechDataset(unittest.TestCase):
+class TestTTSDataset(unittest.TestCase):
     def __init__(self, *args, **kwargs):
-        super(TestLJSpeechDataset, self).__init__(*args, **kwargs)
+        super(TestTTSDataset, self).__init__(*args, **kwargs)
         self.max_loader_iter = 4
-        self.ap = AudioProcessor(
-            sample_rate=c.sample_rate,
-            num_mels=c.num_mels,
-            min_level_db=c.min_level_db,
-            frame_shift_ms=c.frame_shift_ms,
-            frame_length_ms=c.frame_length_ms,
-            ref_level_db=c.ref_level_db,
-            num_freq=c.num_freq,
-            power=c.power,
-            preemphasis=c.preemphasis)
+        self.ap = AudioProcessor(**c.audio)

-    def test_loader(self):
-        if ok_ljspeech:
-            dataset = LJSpeech.MyDataset(
-                os.path.join(c.data_path_LJSpeech),
-                os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
-                c.r,
-                c.text_cleaner,
-                ap=self.ap,
-                min_seq_len=c.min_seq_len)
-            dataloader = DataLoader(
-                dataset,
-                batch_size=2,
-                shuffle=True,
-                collate_fn=dataset.collate_fn,
-                drop_last=True,
-                num_workers=c.num_loader_workers)
-            for i, data in enumerate(dataloader):
-                if i == self.max_loader_iter:
-                    break
-                text_input = data[0]
-                text_lengths = data[1]
-                linear_input = data[2]
-                mel_input = data[3]
-                mel_lengths = data[4]
-                stop_target = data[5]
-                item_idx = data[6]
-
-                neg_values = text_input[text_input < 0]
-                check_count = len(neg_values)
-                assert check_count == 0, \
-                    " !! Negative values in text_input: {}".format(check_count)
-                # TODO: more assertion here
-                assert linear_input.shape[0] == c.batch_size
-                assert mel_input.shape[0] == c.batch_size
-                assert mel_input.shape[2] == c.num_mels
-
-    def test_batch_group_shuffle(self):
-        if ok_ljspeech:
-            dataset = LJSpeech.MyDataset(
-                os.path.join(c.data_path_LJSpeech),
-                os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
-                c.r,
-                c.text_cleaner,
-                ap=self.ap,
-                batch_group_size=16,
-                min_seq_len=c.min_seq_len)
-            dataloader = DataLoader(
-                dataset,
-                batch_size=2,
-                shuffle=True,
-                collate_fn=dataset.collate_fn,
-                drop_last=True,
-                num_workers=c.num_loader_workers)
-            frames = dataset.frames
-            for i, data in enumerate(dataloader):
-                if i == self.max_loader_iter:
-                    break
-                text_input = data[0]
-                text_lengths = data[1]
-                linear_input = data[2]
-                mel_input = data[3]
-                mel_lengths = data[4]
-                stop_target = data[5]
-                item_idx = data[6]
-
-                neg_values = text_input[text_input < 0]
-                check_count = len(neg_values)
-                assert check_count == 0, \
-                    " !! Negative values in text_input: {}".format(check_count)
-                # TODO: more assertion here
-                assert linear_input.shape[0] == c.batch_size
-                assert mel_input.shape[0] == c.batch_size
-                assert mel_input.shape[2] == c.num_mels
-            dataloader.dataset.sort_frames()
-            assert frames[0] != dataloader.dataset.frames[0]
-
-    def test_padding(self):
-        if ok_ljspeech:
-            dataset = LJSpeech.MyDataset(
-                os.path.join(c.data_path_LJSpeech),
-                os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
-                1,
-                c.text_cleaner,
-                ap=self.ap,
-                min_seq_len=c.min_seq_len)
-            # Test for batch size 1
-            dataloader = DataLoader(
-                dataset,
-                batch_size=1,
-                shuffle=False,
-                collate_fn=dataset.collate_fn,
-                drop_last=True,
-                num_workers=c.num_loader_workers)
+    def _create_dataloader(self, batch_size, r, bgs):
+        dataset = TTSDataset.MyDataset(
+            c.data_path,
+            'metadata.csv',
+            r,
+            c.text_cleaner,
+            preprocessor=ljspeech,
+            ap=self.ap,
+            batch_group_size=bgs,
+            min_seq_len=c.min_seq_len)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            collate_fn=dataset.collate_fn,
+            drop_last=True,
+            num_workers=c.num_loader_workers)
+        return dataloader, dataset
+
+    def test_loader(self):
+        if ok_ljspeech:
+            dataloader, dataset = self._create_dataloader(2, c.r, 0)

             for i, data in enumerate(dataloader):
                 if i == self.max_loader_iter:
@@ -139,11 +56,251 @@ class TestLJSpeechDataset(unittest.TestCase):
                 stop_target = data[5]
                 item_idx = data[6]

+                neg_values = text_input[text_input < 0]
+                check_count = len(neg_values)
+                assert check_count == 0, \
+                    " !! Negative values in text_input: {}".format(check_count)
+                # TODO: more assertion here
+                assert linear_input.shape[0] == c.batch_size
+                assert linear_input.shape[2] == self.ap.num_freq
+                assert mel_input.shape[0] == c.batch_size
+                assert mel_input.shape[2] == c.audio['num_mels']
+                # check normalization ranges
+                if self.ap.symmetric_norm:
+                    assert mel_input.max() <= self.ap.max_norm
+                    assert mel_input.min() >= -self.ap.max_norm
+                    assert mel_input.min() < 0
+                else:
+                    assert mel_input.max() <= self.ap.max_norm
+                    assert mel_input.min() >= 0
+
+    def test_batch_group_shuffle(self):
+        if ok_ljspeech:
+            dataloader, dataset = self._create_dataloader(2, c.r, 16)
+            last_length = 0
+            frames = dataset.items
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
+                avg_length = mel_lengths.numpy().mean()
+                assert avg_length >= last_length
+            dataloader.dataset.sort_items()
+            assert frames[0] != dataloader.dataset.items[0]
+
+    def test_padding_and_spec(self):
+        if ok_ljspeech:
+            dataloader, dataset = self._create_dataloader(1, 1, 0)
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
+
+                # check mel_spec consistency
+                wav = self.ap.load_wav(item_idx[0])
+                mel = self.ap.melspectrogram(wav)
+                mel_dl = mel_input[0].cpu().numpy()
+                assert (
+                    abs(mel.T).astype("float32") - abs(mel_dl[:-1])).sum() == 0
+
+                # check mel-spec correctness
+                mel_spec = mel_input[0].cpu().numpy()
+                wav = self.ap.inv_mel_spectrogram(mel_spec.T)
+                self.ap.save_wav(wav, OUTPATH + '/mel_inv_dataloader.wav')
+                shutil.copy(item_idx[0], OUTPATH + '/mel_target_dataloader.wav')
+
+                # check linear-spec
+                linear_spec = linear_input[0].cpu().numpy()
+                wav = self.ap.inv_spectrogram(linear_spec.T)
+                self.ap.save_wav(wav, OUTPATH + '/linear_inv_dataloader.wav')
+                shutil.copy(item_idx[0], OUTPATH + '/linear_target_dataloader.wav')
+
+                # check the last time step to be zero padded
+                assert linear_input[0, -1].sum() == 0
+                assert linear_input[0, -2].sum() != 0
+                assert mel_input[0, -1].sum() == 0
+                assert mel_input[0, -2].sum() != 0
+                assert stop_target[0, -1] == 1
+                assert stop_target[0, -2] == 0
+                assert stop_target.sum() == 1
+                assert len(mel_lengths.shape) == 1
+                assert mel_lengths[0] == linear_input[0].shape[0]
+                assert mel_lengths[0] == mel_input[0].shape[0]
+
+            # Test for batch size 2
+            dataloader, dataset = self._create_dataloader(2, 1, 0)
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
+
+                if mel_lengths[0] > mel_lengths[1]:
+                    idx = 0
+                else:
+                    idx = 1
+
+                # check the first item in the batch
+                assert linear_input[idx, -1].sum() == 0
+                assert linear_input[idx, -2].sum() != 0, linear_input
+                assert mel_input[idx, -1].sum() == 0
+                assert mel_input[idx, -2].sum() != 0, mel_input
+                assert stop_target[idx, -1] == 1
+                assert stop_target[idx, -2] == 0
+                assert stop_target[idx].sum() == 1
+                assert len(mel_lengths.shape) == 1
+                assert mel_lengths[idx] == mel_input[idx].shape[0]
+                assert mel_lengths[idx] == linear_input[idx].shape[0]
+
+                # check the second item in the batch
+                assert linear_input[1 - idx, -1].sum() == 0
+                assert mel_input[1 - idx, -1].sum() == 0
+                assert stop_target[1 - idx, -1] == 1
+                assert len(mel_lengths.shape) == 1
+
+                # check batch conditions
+                assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
+                assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
+
+
+class TestTTSDatasetCached(unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super(TestTTSDatasetCached, self).__init__(*args, **kwargs)
+        self.max_loader_iter = 4
+        self.c = load_config(os.path.join(c.data_path_cache, 'config.json'))
+        self.ap = AudioProcessor(**self.c.audio)
+
+    def _create_dataloader(self, batch_size, r, bgs):
+        dataset = TTSDatasetCached.MyDataset(
+            c.data_path_cache,
+            'tts_metadata.csv',
+            r,
+            c.text_cleaner,
+            preprocessor=tts_cache,
+            ap=self.ap,
+            batch_group_size=bgs,
+            min_seq_len=c.min_seq_len)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            collate_fn=dataset.collate_fn,
+            drop_last=True,
+            num_workers=c.num_loader_workers)
+        return dataloader, dataset
+
+    def test_loader(self):
+        if ok_ljspeech:
+            dataloader, dataset = self._create_dataloader(2, c.r, 0)
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
+
+                neg_values = text_input[text_input < 0]
+                check_count = len(neg_values)
+                assert check_count == 0, \
+                    " !! Negative values in text_input: {}".format(check_count)
+                # TODO: more assertion here
+                assert mel_input.shape[0] == c.batch_size
+                assert mel_input.shape[2] == c.audio['num_mels']
+                if self.ap.symmetric_norm:
+                    assert mel_input.max() <= self.ap.max_norm
+                    assert mel_input.min() >= -self.ap.max_norm
+                    assert mel_input.min() < 0
+                else:
+                    assert mel_input.max() <= self.ap.max_norm
+                    assert mel_input.min() >= 0
+
+    def test_batch_group_shuffle(self):
+        if ok_ljspeech:
+            dataloader, dataset = self._create_dataloader(2, c.r, 16)
+            frames = dataset.items
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
+
+                neg_values = text_input[text_input < 0]
+                check_count = len(neg_values)
+                assert check_count == 0, \
+                    " !! Negative values in text_input: {}".format(check_count)
+                # TODO: more assertion here
+                assert mel_input.shape[0] == c.batch_size
+                assert mel_input.shape[2] == c.audio['num_mels']
+            dataloader.dataset.sort_items()
+            assert frames[0] != dataloader.dataset.items[0]
+
+    def test_padding_and_spec(self):
+        if ok_ljspeech:
+            dataloader, dataset = self._create_dataloader(1, 1, 0)
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
+
+                # check mel_spec consistency
+                if item_idx[0].split('.')[-1] == 'npy':
+                    wav = np.load(item_idx[0])
+                else:
+                    wav = self.ap.load_wav(item_idx[0])
+                mel = self.ap.melspectrogram(wav)
+                mel_dl = mel_input[0].cpu().numpy()
+                assert (abs(mel.T).astype("float32") - abs(
+                    mel_dl[:-1])).sum() == 0, (
+                        abs(mel.T).astype("float32") - abs(mel_dl[:-1])).sum()
+
+                # check mel-spec correctness
+                mel_spec = mel_input[0].cpu().numpy()
+                wav = self.ap.inv_mel_spectrogram(mel_spec.T)
+                self.ap.save_wav(wav,
+                                 OUTPATH + '/mel_inv_dataloader_cache.wav')
+                shutil.copy(item_idx[0], OUTPATH + '/mel_target_dataloader_cache.wav')
+
                 # check the last time step to be zero padded
                 assert mel_input[0, -1].sum() == 0
                 assert mel_input[0, -2].sum() != 0
-                assert linear_input[0, -1].sum() == 0
-                assert linear_input[0, -2].sum() != 0
                 assert stop_target[0, -1] == 1
                 assert stop_target[0, -2] == 0
                 assert stop_target.sum() == 1
@@ -151,14 +308,7 @@ class TestLJSpeechDataset(unittest.TestCase):
                 assert mel_lengths[0] == mel_input[0].shape[0]

             # Test for batch size 2
-            dataloader = DataLoader(
-                dataset,
-                batch_size=2,
-                shuffle=False,
-                collate_fn=dataset.collate_fn,
-                drop_last=False,
-                num_workers=c.num_loader_workers)
+            dataloader, dataset = self._create_dataloader(2, 1, 0)
             for i, data in enumerate(dataloader):
                 if i == self.max_loader_iter:
                     break
@@ -178,8 +328,6 @@ class TestLJSpeechDataset(unittest.TestCase):
                 # check the first item in the batch
                 assert mel_input[idx, -1].sum() == 0
                 assert mel_input[idx, -2].sum() != 0, mel_input
-                assert linear_input[idx, -1].sum() == 0
-                assert linear_input[idx, -2].sum() != 0
                 assert stop_target[idx, -1] == 1
                 assert stop_target[idx, -2] == 0
                 assert stop_target[idx].sum() == 1
@ -188,151 +336,202 @@ class TestLJSpeechDataset(unittest.TestCase):
# check the second itme in the batch # check the second itme in the batch
assert mel_input[1 - idx, -1].sum() == 0 assert mel_input[1 - idx, -1].sum() == 0
assert linear_input[1 - idx, -1].sum() == 0
assert stop_target[1 - idx, -1] == 1 assert stop_target[1 - idx, -1] == 1
assert len(mel_lengths.shape) == 1 assert len(mel_lengths.shape) == 1
# check batch conditions # check batch conditions
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0 assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
class TestKusalDataset(unittest.TestCase): # class TestTTSDatasetMemory(unittest.TestCase):
def __init__(self, *args, **kwargs): # def __init__(self, *args, **kwargs):
super(TestKusalDataset, self).__init__(*args, **kwargs) # super(TestTTSDatasetMemory, self).__init__(*args, **kwargs)
self.max_loader_iter = 4 # self.max_loader_iter = 4
self.ap = AudioProcessor( # self.c = load_config(os.path.join(c.data_path_cache, 'config.json'))
sample_rate=c.sample_rate, # self.ap = AudioProcessor(**c.audio)
num_mels=c.num_mels,
min_level_db=c.min_level_db,
frame_shift_ms=c.frame_shift_ms,
frame_length_ms=c.frame_length_ms,
ref_level_db=c.ref_level_db,
num_freq=c.num_freq,
power=c.power,
preemphasis=c.preemphasis)
def test_loader(self): # def test_loader(self):
if ok_kusal: # if ok_ljspeech:
dataset = Kusal.MyDataset( # dataset = TTSDatasetMemory.MyDataset(
os.path.join(c.data_path_Kusal), # c.data_path_cache,
os.path.join(c.data_path_Kusal, 'prompts.txt'), # 'tts_metadata.csv',
c.r, # c.r,
c.text_cleaner, # c.text_cleaner,
ap=self.ap, # preprocessor=tts_cache,
min_seq_len=c.min_seq_len) # ap=self.ap,
# min_seq_len=c.min_seq_len)
dataloader = DataLoader( # dataloader = DataLoader(
dataset, # dataset,
batch_size=2, # batch_size=2,
shuffle=True, # shuffle=True,
collate_fn=dataset.collate_fn, # collate_fn=dataset.collate_fn,
drop_last=True, # drop_last=True,
num_workers=c.num_loader_workers) # num_workers=c.num_loader_workers)
for i, data in enumerate(dataloader): # for i, data in enumerate(dataloader):
if i == self.max_loader_iter: # if i == self.max_loader_iter:
break # break
text_input = data[0] # text_input = data[0]
text_lengths = data[1] # text_lengths = data[1]
linear_input = data[2] # linear_input = data[2]
mel_input = data[3] # mel_input = data[3]
mel_lengths = data[4] # mel_lengths = data[4]
stop_target = data[5] # stop_target = data[5]
item_idx = data[6] # item_idx = data[6]
neg_values = text_input[text_input < 0] # neg_values = text_input[text_input < 0]
check_count = len(neg_values) # check_count = len(neg_values)
assert check_count == 0, \ # assert check_count == 0, \
" !! Negative values in text_input: {}".format(check_count) # " !! Negative values in text_input: {}".format(check_count)
# TODO: more assertion here # # check mel-spec shape
assert linear_input.shape[0] == c.batch_size # assert mel_input.shape[0] == c.batch_size
assert mel_input.shape[0] == c.batch_size # assert mel_input.shape[2] == c.audio['num_mels']
assert mel_input.shape[2] == c.num_mels # assert mel_input.max() <= self.ap.max_norm
# # check data range
# if self.ap.symmetric_norm:
# assert mel_input.max() <= self.ap.max_norm
# assert mel_input.min() >= -self.ap.max_norm
# assert mel_input.min() < 0
# else:
# assert mel_input.max() <= self.ap.max_norm
# assert mel_input.min() >= 0
def test_padding(self): # def test_batch_group_shuffle(self):
if ok_kusal: # if ok_ljspeech:
dataset = Kusal.MyDataset( # dataset = TTSDatasetMemory.MyDataset(
os.path.join(c.data_path_Kusal), # c.data_path_cache,
os.path.join(c.data_path_Kusal, 'prompts.txt'), # 'tts_metadata.csv',
1, # c.r,
c.text_cleaner, # c.text_cleaner,
ap=self.ap, # preprocessor=ljspeech,
min_seq_len=c.min_seq_len) # ap=self.ap,
# batch_group_size=16,
# min_seq_len=c.min_seq_len)
# Test for batch size 1 # dataloader = DataLoader(
dataloader = DataLoader( # dataset,
dataset, # batch_size=2,
batch_size=1, # shuffle=True,
shuffle=False, # collate_fn=dataset.collate_fn,
collate_fn=dataset.collate_fn, # drop_last=True,
drop_last=True, # num_workers=c.num_loader_workers)
num_workers=c.num_loader_workers)
for i, data in enumerate(dataloader): # frames = dataset.items
if i == self.max_loader_iter: # for i, data in enumerate(dataloader):
break # if i == self.max_loader_iter:
text_input = data[0] # break
text_lengths = data[1] # text_input = data[0]
linear_input = data[2] # text_lengths = data[1]
mel_input = data[3] # linear_input = data[2]
mel_lengths = data[4] # mel_input = data[3]
stop_target = data[5] # mel_lengths = data[4]
item_idx = data[6] # stop_target = data[5]
# item_idx = data[6]
# check the last time step to be zero padded # neg_values = text_input[text_input < 0]
assert mel_input[0, -1].sum() == 0 # check_count = len(neg_values)
# assert mel_input[0, -2].sum() != 0 # assert check_count == 0, \
assert linear_input[0, -1].sum() == 0 # " !! Negative values in text_input: {}".format(check_count)
# assert linear_input[0, -2].sum() != 0 # assert mel_input.shape[0] == c.batch_size
assert stop_target[0, -1] == 1 # assert mel_input.shape[2] == c.audio['num_mels']
assert stop_target[0, -2] == 0 # dataloader.dataset.sort_items()
assert stop_target.sum() == 1 # assert frames[0] != dataloader.dataset.items[0]
assert len(mel_lengths.shape) == 1
assert mel_lengths[0] == mel_input[0].shape[0]
# Test for batch size 2 # def test_padding_and_spec(self):
dataloader = DataLoader( # if ok_ljspeech:
dataset, # dataset = TTSDatasetMemory.MyDataset(
batch_size=2, # c.data_path_cache,
shuffle=False, # 'tts_metadata.csv',
collate_fn=dataset.collate_fn, # 1,
drop_last=False, # c.text_cleaner,
num_workers=c.num_loader_workers) # preprocessor=ljspeech,
# ap=self.ap,
# min_seq_len=c.min_seq_len)
for i, data in enumerate(dataloader): # # Test for batch size 1
if i == self.max_loader_iter: # dataloader = DataLoader(
break # dataset,
text_input = data[0] # batch_size=1,
text_lengths = data[1] # shuffle=False,
linear_input = data[2] # collate_fn=dataset.collate_fn,
mel_input = data[3] # drop_last=True,
mel_lengths = data[4] # num_workers=c.num_loader_workers)
stop_target = data[5]
item_idx = data[6]
if mel_lengths[0] > mel_lengths[1]: # for i, data in enumerate(dataloader):
idx = 0 # if i == self.max_loader_iter:
else: # break
idx = 1 # text_input = data[0]
# text_lengths = data[1]
# linear_input = data[2]
# mel_input = data[3]
# mel_lengths = data[4]
# stop_target = data[5]
# item_idx = data[6]
# check the first item in the batch # # check mel_spec consistency
assert mel_input[idx, -1].sum() == 0 # if item_idx[0].split('.')[-1] == 'npy':
assert mel_input[idx, -2].sum() != 0, mel_input # wav = np.load(item_idx[0])
assert linear_input[idx, -1].sum() == 0 # else:
assert linear_input[idx, -2].sum() != 0 # wav = self.ap.load_wav(item_idx[0])
assert stop_target[idx, -1] == 1 # mel = self.ap.melspectrogram(wav)
assert stop_target[idx, -2] == 0 # mel_dl = mel_input[0].cpu().numpy()
assert stop_target[idx].sum() == 1 # assert (
assert len(mel_lengths.shape) == 1 # abs(mel.T).astype("float32") - abs(mel_dl[:-1])).sum() == 0
assert mel_lengths[idx] == mel_input[idx].shape[0]
# check the second item in the batch # # check mel-spec correctness
assert mel_input[1 - idx, -1].sum() == 0 # mel_spec = mel_input[0].cpu().numpy()
assert linear_input[1 - idx, -1].sum() == 0 # wav = self.ap.inv_mel_spectrogram(mel_spec.T)
assert stop_target[1 - idx, -1] == 1 # self.ap.save_wav(wav, OUTPATH + '/mel_inv_dataloader_memo.wav')
assert len(mel_lengths.shape) == 1 # shutil.copy(item_idx[0], OUTPATH + '/mel_target_dataloader_memo.wav')
# check batch conditions # # check the last time step to be zero padded
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0 # assert mel_input[0, -1].sum() == 0
assert (linear_input * stop_target.unsqueeze(2)).sum() == 0 # assert mel_input[0, -2].sum() != 0
# assert stop_target[0, -1] == 1
# assert stop_target[0, -2] == 0
# assert stop_target.sum() == 1
# assert len(mel_lengths.shape) == 1
# assert mel_lengths[0] == mel_input[0].shape[0]
# # Test for batch size 2
# dataloader = DataLoader(
# dataset,
# batch_size=2,
# shuffle=False,
# collate_fn=dataset.collate_fn,
# drop_last=False,
# num_workers=c.num_loader_workers)
# for i, data in enumerate(dataloader):
# if i == self.max_loader_iter:
# break
# text_input = data[0]
# text_lengths = data[1]
# linear_input = data[2]
# mel_input = data[3]
# mel_lengths = data[4]
# stop_target = data[5]
# item_idx = data[6]
# if mel_lengths[0] > mel_lengths[1]:
# idx = 0
# else:
# idx = 1
# # check the first item in the batch
# assert mel_input[idx, -1].sum() == 0
# assert mel_input[idx, -2].sum() != 0, mel_input
# assert stop_target[idx, -1] == 1
# assert stop_target[idx, -2] == 0
# assert stop_target[idx].sum() == 1
# assert len(mel_lengths.shape) == 1
# assert mel_lengths[idx] == mel_input[idx].shape[0]
# # check the second item in the batch
# assert mel_input[1 - idx, -1].sum() == 0
# assert stop_target[1 - idx, -1] == 1
# assert len(mel_lengths.shape) == 1
# # check batch conditions
# assert (mel_input * stop_target.unsqueeze(2)).sum() == 0

View File

@ -21,8 +21,8 @@ c = load_config(os.path.join(file_path, 'test_config.json'))
class TacotronTrainTest(unittest.TestCase): class TacotronTrainTest(unittest.TestCase):
def test_train_step(self): def test_train_step(self):
input = torch.randint(0, 24, (8, 128)).long().to(device) input = torch.randint(0, 24, (8, 128)).long().to(device)
mel_spec = torch.rand(8, 30, c.num_mels).to(device) mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
linear_spec = torch.rand(8, 30, c.num_freq).to(device) linear_spec = torch.rand(8, 30, c.audio['num_freq']).to(device)
mel_lengths = torch.randint(20, 30, (8, )).long().to(device) mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
stop_targets = torch.zeros(8, 30, 1).float().to(device) stop_targets = torch.zeros(8, 30, 1).float().to(device)
@ -35,7 +35,7 @@ class TacotronTrainTest(unittest.TestCase):
criterion = L1LossMasked().to(device) criterion = L1LossMasked().to(device)
criterion_st = nn.BCELoss().to(device) criterion_st = nn.BCELoss().to(device)
model = Tacotron(c.embedding_size, c.num_freq, c.num_mels, model = Tacotron(c.embedding_size, c.audio['num_freq'], c.audio['num_mels'],
c.r).to(device) c.r).to(device)
model.train() model.train()
model_ref = copy.deepcopy(model) model_ref = copy.deepcopy(model)
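Note: a condensed sketch of the stop-token loss this test exercises, reusing the shapes above; pairing BCELoss with sigmoid outputs is assumed from the criterion_st setup, and the masked L1 spectrogram loss is left out:

    import torch
    import torch.nn as nn

    stop_targets = torch.zeros(8, 30, 1)
    stop_logits = torch.rand(8, 30, 1)   # stand-in for the model's stopnet output
    criterion_st = nn.BCELoss()
    stop_loss = criterion_st(torch.sigmoid(stop_logits), stop_targets)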

View File

@ -1,16 +1,25 @@
{ {
"num_mels": 80, "audio":{
"num_freq": 1025, "audio_processor": "audio", // to use dictate different audio processors, if available.
"sample_rate": 22050, "num_mels": 80, // size of the mel spec frame.
"frame_length_ms": 50, "num_freq": 1025, // number of stft frequency levels. Size of the linear spectogram frame.
"frame_shift_ms": 12.5, "sample_rate": 22050, // wav sample-rate. If different than the original data, it is resampled.
"preemphasis": 0.97, "frame_length_ms": 50, // stft window length in ms.
"min_level_db": -100, "frame_shift_ms": 12.5, // stft window hop-lengh in ms.
"ref_level_db": 20, "preemphasis": 0.97, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
"min_level_db": -100, // normalization range
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
"power": 1.5, // value to sharpen wav signals after GL algorithm.
"griffin_lim_iters": 30,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
"signal_norm": true, // normalize the spec values in range [0, 1]
"symmetric_norm": true, // move normalization to range [-1, 1]
"clip_norm": true, // clip normalized values into the range.
"max_norm": 4, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
"mel_fmin": 95, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
"mel_fmax": 7600 // maximum freq level for mel-spec. Tune for dataset!!
},
"hidden_size": 128, "hidden_size": 128,
"embedding_size": 256, "embedding_size": 256,
"min_mel_freq": null,
"max_mel_freq": null,
"text_cleaner": "english_cleaners", "text_cleaner": "english_cleaners",
"epochs": 2000, "epochs": 2000,
@ -21,16 +30,11 @@
"r": 5, "r": 5,
"mk": 1.0, "mk": 1.0,
"priority_freq": false, "priority_freq": false,
"griffin_lim_iters": 60,
"power": 1.5,
"num_loader_workers": 4, "num_loader_workers": 4,
"save_step": 200, "save_step": 200,
"data_path_LJSpeech": "/home/erogol/Data/LJSpeech-1.1", "data_path": "/home/erogol/Data/LJSpeech-1.1/",
"data_path_Kusal": "/home/erogol/Data/Kusal", "data_path_cache": "/home/erogol/Data/LJSpeech-1.1/tts_cache/",
"output_path": "result", "output_path": "result",
"min_seq_len": 0, "min_seq_len": 0,
"log_dir": "/home/erogol/projects/TTS/logs/" "log_dir": "/home/erogol/projects/TTS/logs/"

View File

@ -14,16 +14,16 @@ from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from utils.generic_utils import ( from utils.generic_utils import (
synthesis, remove_experiment_folder, create_experiment_folder, remove_experiment_folder, create_experiment_folder,
save_checkpoint, save_best_model, load_config, lr_decay, count_parameters, save_checkpoint, save_best_model, load_config, lr_decay, count_parameters,
check_update, get_commit_hash, sequence_mask, AnnealLR) check_update, get_commit_hash, sequence_mask, AnnealLR)
from utils.visual import plot_alignment, plot_spectrogram from utils.visual import plot_alignment, plot_spectrogram
from models.tacotron import Tacotron from models.tacotron import Tacotron
from layers.losses import L1LossMasked from layers.losses import L1LossMasked
from utils.audio import AudioProcessor from utils.audio import AudioProcessor
from utils.synthesis import synthesis
torch.manual_seed(1) torch.manual_seed(1)
# torch.set_num_threads(4)
use_cuda = torch.cuda.is_available() use_cuda = torch.cuda.is_available()
@ -36,7 +36,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
avg_stop_loss = 0 avg_stop_loss = 0
avg_step_time = 0 avg_step_time = 0
print(" | > Epoch {}/{}".format(epoch, c.epochs), flush=True) print(" | > Epoch {}/{}".format(epoch, c.epochs), flush=True)
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq) n_priority_freq = int(3000 / (c.audio['sample_rate'] * 0.5) * c.audio['num_freq'])
batch_n_iter = int(len(data_loader.dataset) / c.batch_size) batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
for num_iter, data in enumerate(data_loader): for num_iter, data in enumerate(data_loader):
start_time = time.time() start_time = time.time()
@ -215,7 +215,7 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
"I'm sorry Dave. I'm afraid I can't do that.", "I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist." "This cake is great. It's so delicious and moist."
] ]
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq) n_priority_freq = int(3000 / (c.audio['sample_rate'] * 0.5) * c.audio['num_freq'])
with torch.no_grad(): with torch.no_grad():
if data_loader is not None: if data_loader is not None:
for num_iter, data in enumerate(data_loader): for num_iter, data in enumerate(data_loader):
@ -277,6 +277,7 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
const_spec = linear_output[idx].data.cpu().numpy() const_spec = linear_output[idx].data.cpu().numpy()
gt_spec = linear_input[idx].data.cpu().numpy() gt_spec = linear_input[idx].data.cpu().numpy()
align_img = alignments[idx].data.cpu().numpy() align_img = alignments[idx].data.cpu().numpy()
const_spec = plot_spectrogram(const_spec, ap) const_spec = plot_spectrogram(const_spec, ap)
gt_spec = plot_spectrogram(gt_spec, ap) gt_spec = plot_spectrogram(gt_spec, ap)
align_img = plot_alignment(align_img) align_img = plot_alignment(align_img)
@ -319,8 +320,8 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
ap.griffin_lim_iters = 60 ap.griffin_lim_iters = 60
for idx, test_sentence in enumerate(test_sentences): for idx, test_sentence in enumerate(test_sentences):
try: try:
wav, linear_spec, alignments = synthesis(model, ap, test_sentence, wav, alignment, linear_spec, stop_tokens = synthesis(model, test_sentence, c,
use_cuda, c.text_cleaner) use_cuda, ap)
file_path = os.path.join(AUDIO_PATH, str(current_step)) file_path = os.path.join(AUDIO_PATH, str(current_step))
os.makedirs(file_path, exist_ok=True) os.makedirs(file_path, exist_ok=True)
@ -330,7 +331,7 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
wav_name = 'TestSentences/{}'.format(idx) wav_name = 'TestSentences/{}'.format(idx)
tb.add_audio( tb.add_audio(
wav_name, wav, current_step, sample_rate=c.sample_rate) wav_name, wav, current_step, sample_rate=c.audio['sample_rate'])
align_img = alignments[0].data.cpu().numpy() align_img = alignments[0].data.cpu().numpy()
linear_spec = plot_spectrogram(linear_spec, ap) linear_spec = plot_spectrogram(linear_spec, ap)
align_img = plot_alignment(align_img) align_img = plot_alignment(align_img)
@ -345,28 +346,22 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
def main(args): def main(args):
dataset = importlib.import_module('datasets.' + c.dataset) preprocessor = importlib.import_module('datasets.preprocess')
Dataset = getattr(dataset, 'MyDataset') preprocessor = getattr(preprocessor, c.dataset.lower())
audio = importlib.import_module('utils.' + c.audio_processor) MyDataset = importlib.import_module('datasets.'+c.data_loader)
MyDataset = getattr(MyDataset, "MyDataset")
audio = importlib.import_module('utils.' + c.audio['audio_processor'])
AudioProcessor = getattr(audio, 'AudioProcessor') AudioProcessor = getattr(audio, 'AudioProcessor')
ap = AudioProcessor( ap = AudioProcessor(**c.audio)
sample_rate=c.sample_rate,
num_mels=c.num_mels,
min_level_db=c.min_level_db,
frame_shift_ms=c.frame_shift_ms,
frame_length_ms=c.frame_length_ms,
ref_level_db=c.ref_level_db,
num_freq=c.num_freq,
power=c.power,
preemphasis=c.preemphasis)
# Setup the dataset # Setup the dataset
train_dataset = Dataset( train_dataset = MyDataset(
c.data_path, c.data_path,
c.meta_file_train, c.meta_file_train,
c.r, c.r,
c.text_cleaner, c.text_cleaner,
preprocessor=preprocessor,
ap=ap, ap=ap,
batch_group_size=8*c.batch_size, batch_group_size=8*c.batch_size,
min_seq_len=c.min_seq_len) min_seq_len=c.min_seq_len)
@ -381,8 +376,8 @@ def main(args):
pin_memory=True) pin_memory=True)
if c.run_eval: if c.run_eval:
val_dataset = Dataset( val_dataset = MyDataset(
c.data_path, c.meta_file_val, c.r, c.text_cleaner, ap=ap, batch_group_size=0) c.data_path, c.meta_file_val, c.r, c.text_cleaner, preprocessor=preprocessor, ap=ap, batch_group_size=0)
val_loader = DataLoader( val_loader = DataLoader(
val_dataset, val_dataset,
@ -395,7 +390,7 @@ def main(args):
else: else:
val_loader = None val_loader = None
model = Tacotron(c.embedding_size, ap.num_freq, c.num_mels, c.r) model = Tacotron(c.embedding_size, ap.num_freq, ap.num_mels, c.r)
print(" | > Num output units : {}".format(ap.num_freq), flush=True) print(" | > Num output units : {}".format(ap.num_freq), flush=True)
optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0) optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0)
@ -431,7 +426,7 @@ def main(args):
criterion.cuda() criterion.cuda()
criterion_st.cuda() criterion_st.cuda()
scheduler = AnnealLR(optimizer, warmup_steps=c.warmup_steps, last_epoch= (args.restore_step-1)) scheduler = AnnealLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)
num_params = count_parameters(model) num_params = count_parameters(model)
print(" | > Model has {} parameters".format(num_params), flush=True) print(" | > Model has {} parameters".format(num_params), flush=True)
@ -454,7 +449,7 @@ def main(args):
best_loss = save_best_model(model, optimizer, train_loss, best_loss, best_loss = save_best_model(model, optimizer, train_loss, best_loss,
OUT_PATH, current_step, epoch) OUT_PATH, current_step, epoch)
# shuffle batch groups # shuffle batch groups
train_loader.dataset.sort_frames() train_loader.dataset.sort_items()
if __name__ == '__main__': if __name__ == '__main__':
@ -477,8 +472,9 @@ if __name__ == '__main__':
parser.add_argument( parser.add_argument(
'--data_path', '--data_path',
type=str, type=str,
default='', help='dataset path.',
help='data path to overwrite config.json') default=''
)
args = parser.parse_args() args = parser.parse_args()
# setup output paths and read configs # setup output paths and read configs
@ -491,7 +487,7 @@ if __name__ == '__main__':
os.makedirs(AUDIO_PATH, exist_ok=True) os.makedirs(AUDIO_PATH, exist_ok=True)
shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json')) shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
if args.data_path != "": if args.data_path != '':
c.data_path = args.data_path c.data_path = args.data_path
# setup tensorboard # setup tensorboard
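Note: AnnealLR is constructed from warmup_steps and a last_epoch offset only, which suggests a Noam-style schedule; a hypothetical stand-in (class name and decay law assumed, not taken from this diff):

    from torch.optim.lr_scheduler import _LRScheduler

    class NoamLR(_LRScheduler):
        """Linear warmup for warmup_steps, then inverse square-root decay."""
        def __init__(self, optimizer, warmup_steps=4000, last_epoch=-1):
            self.warmup_steps = warmup_steps
            super(NoamLR, self).__init__(optimizer, last_epoch)

        def get_lr(self):
            step = max(self.last_epoch, 1)
            scale = self.warmup_steps**0.5 * min(step**-0.5,
                                                 step * self.warmup_steps**-1.5)
            return [lr * scale for lr in self.base_lrs]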

View File

@ -3,26 +3,34 @@ import librosa
import pickle import pickle
import copy import copy
import numpy as np import numpy as np
import scipy from pprint import pprint
from scipy import signal from scipy import signal, io
_mel_basis = None
class AudioProcessor(object): class AudioProcessor(object):
def __init__(self, def __init__(self,
sample_rate, bits=None,
num_mels, sample_rate=None,
min_level_db, num_mels=None,
frame_shift_ms, min_level_db=None,
frame_length_ms, frame_shift_ms=None,
ref_level_db, frame_length_ms=None,
num_freq, ref_level_db=None,
power, num_freq=None,
preemphasis, power=None,
griffin_lim_iters=None): preemphasis=None,
signal_norm=None,
symmetric_norm=None,
max_norm=None,
mel_fmin=None,
mel_fmax=None,
clip_norm=True,
griffin_lim_iters=None,
**kwargs):
print(" > Setting up Audio Processor...") print(" > Setting up Audio Processor...")
self.bits = bits
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.num_mels = num_mels self.num_mels = num_mels
self.min_level_db = min_level_db self.min_level_db = min_level_db
@ -33,33 +41,79 @@ class AudioProcessor(object):
self.power = power self.power = power
self.preemphasis = preemphasis self.preemphasis = preemphasis
self.griffin_lim_iters = griffin_lim_iters self.griffin_lim_iters = griffin_lim_iters
self.signal_norm = signal_norm
self.symmetric_norm = symmetric_norm
self.mel_fmin = 0 if mel_fmin is None else mel_fmin
self.mel_fmax = mel_fmax
self.max_norm = 1.0 if max_norm is None else float(max_norm)
self.clip_norm = clip_norm
self.n_fft, self.hop_length, self.win_length = self._stft_parameters() self.n_fft, self.hop_length, self.win_length = self._stft_parameters()
if preemphasis == 0: if preemphasis == 0:
print(" | > Preemphasis is deactive.") print(" | > Preemphasis is deactive.")
print(" | > Audio Processor attributes.")
members = vars(self)
pprint(members)
def save_wav(self, wav, path): def save_wav(self, wav, path):
wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav)))) wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
# librosa.output.write_wav(path, wav_norm.astype(np.int16), self.sample_rate) io.wavfile.write(path, self.sample_rate, wav_norm.astype(np.int16))
scipy.io.wavfile.write(path, self.sample_rate, wav_norm.astype(np.int16))
def _linear_to_mel(self, spectrogram): def _linear_to_mel(self, spectrogram):
global _mel_basis
if _mel_basis is None:
_mel_basis = self._build_mel_basis() _mel_basis = self._build_mel_basis()
return np.dot(_mel_basis, spectrogram) return np.dot(_mel_basis, spectrogram)
def _mel_to_linear(self, mel_spec):
inv_mel_basis = np.linalg.pinv(self._build_mel_basis())
return np.maximum(1e-10, np.dot(inv_mel_basis, mel_spec))
def _build_mel_basis(self, ): def _build_mel_basis(self, ):
n_fft = (self.num_freq - 1) * 2 n_fft = (self.num_freq - 1) * 2
if self.mel_fmax is not None:
assert self.mel_fmax <= self.sample_rate // 2
return librosa.filters.mel( return librosa.filters.mel(
self.sample_rate, n_fft, n_mels=self.num_mels) self.sample_rate,
n_fft,
n_mels=self.num_mels,
fmin=self.mel_fmin,
fmax=self.mel_fmax)
def _normalize(self, S): def _normalize(self, S):
return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1) """Put values in [0, self.max_norm] or [-self.max_norm, self.max_norm]"""
if self.signal_norm:
S_norm = ((S - self.min_level_db) / - self.min_level_db)
if self.symmetric_norm:
S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm
if self.clip_norm:
S_norm = np.clip(S_norm, -self.max_norm, self.max_norm)
return S_norm
else:
S_norm = self.max_norm * S_norm
if self.clip_norm:
S_norm = np.clip(S_norm, 0, self.max_norm)
return S_norm
else:
return S
def _denormalize(self, S): def _denormalize(self, S):
return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db """denormalize values"""
S_denorm = S
if self.signal_norm:
if self.symmetric_norm:
if self.clip_norm:
S_denorm = np.clip(S_denorm, -self.max_norm, self.max_norm)
S_denorm = ((S_denorm + self.max_norm) * -self.min_level_db / (2 * self.max_norm)) + self.min_level_db
return S_denorm
else:
if self.clip_norm:
S_denorm = np.clip(S_denorm, 0, self.max_norm)
S_denorm = (S_denorm * -self.min_level_db /
self.max_norm) + self.min_level_db
return S_denorm
else:
return S
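Note: when clipping never triggers, _normalize and _denormalize are exact inverses; a quick symmetric-norm round trip, with constructor values assumed for illustration:

    import numpy as np

    ap = AudioProcessor(num_freq=1025, sample_rate=22050,
                        frame_shift_ms=12.5, frame_length_ms=50,
                        min_level_db=-100, signal_norm=True,
                        symmetric_norm=True, max_norm=4.0, preemphasis=0.0)
    S = np.linspace(-100.0, 0.0, 11)   # db values spanning the norm range
    assert np.allclose(ap._denormalize(ap._normalize(S)), S)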
def _stft_parameters(self, ): def _stft_parameters(self, ):
"""Compute necessary stft parameters with given time values"""
n_fft = (self.num_freq - 1) * 2 n_fft = (self.num_freq - 1) * 2
hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate) hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate) win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate)
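Note: with the config values above (num_freq 1025, 22050 Hz sample rate, 12.5 ms shift, 50 ms window), the derived STFT parameters work out to:

    n_fft = (1025 - 1) * 2                    # 2048
    hop_length = int(12.5 / 1000.0 * 22050)   # int(275.625) -> 275
    win_length = int(50.0 / 1000.0 * 22050)   # int(1102.5)  -> 1102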
@ -92,8 +146,16 @@ class AudioProcessor(object):
S = self._amp_to_db(np.abs(D)) - self.ref_level_db S = self._amp_to_db(np.abs(D)) - self.ref_level_db
return self._normalize(S) return self._normalize(S)
def melspectrogram(self, y):
if self.preemphasis != 0:
D = self._stft(self.apply_preemphasis(y))
else:
D = self._stft(y)
S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
return self._normalize(S)
def inv_spectrogram(self, spectrogram): def inv_spectrogram(self, spectrogram):
'''Converts spectrogram to waveform using librosa''' """Converts spectrogram to waveform using librosa"""
S = self._denormalize(spectrogram) S = self._denormalize(spectrogram)
S = self._db_to_amp(S + self.ref_level_db) # Convert back to linear S = self._db_to_amp(S + self.ref_level_db) # Convert back to linear
# Reconstruct phase # Reconstruct phase
@ -102,6 +164,16 @@ class AudioProcessor(object):
else: else:
return self._griffin_lim(S**self.power) return self._griffin_lim(S**self.power)
def inv_mel_spectrogram(self, mel_spectrogram):
"""Converts mel spectrogram to waveform using librosa"""
D = self._denormalize(mel_spectrogram)
S = self._db_to_amp(D + self.ref_level_db)
S = self._mel_to_linear(S) # Convert back to linear
if self.preemphasis != 0:
return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
else:
return self._griffin_lim(S**self.power)
def _griffin_lim(self, S): def _griffin_lim(self, S):
angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
S_complex = np.abs(S).astype(np.complex) S_complex = np.abs(S).astype(np.complex)
@ -111,20 +183,17 @@ class AudioProcessor(object):
y = self._istft(S_complex * angles) y = self._istft(S_complex * angles)
return y return y
def melspectrogram(self, y):
if self.preemphasis != 0:
D = self._stft(self.apply_preemphasis(y))
else:
D = self._stft(y)
S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
return self._normalize(S)
def _stft(self, y): def _stft(self, y):
return librosa.stft( return librosa.stft(
y=y, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length) y=y,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
)
def _istft(self, y): def _istft(self, y):
return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length) return librosa.istft(
y, hop_length=self.hop_length, win_length=self.win_length)
def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8): def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
window_length = int(self.sample_rate * min_silence_sec) window_length = int(self.sample_rate * min_silence_sec)
@ -134,3 +203,37 @@ class AudioProcessor(object):
if np.max(wav[x:x + window_length]) < threshold: if np.max(wav[x:x + window_length]) < threshold:
return x + hop_length return x + hop_length
return len(wav) return len(wav)
# WaveRNN repo specific functions
# def mulaw_encode(self, wav, qc):
# mu = qc - 1
# wav_abs = np.minimum(np.abs(wav), 1.0)
# magnitude = np.log(1 + mu * wav_abs) / np.log(1. + mu)
# signal = np.sign(wav) * magnitude
# # Quantize signal to the specified number of levels.
# signal = (signal + 1) / 2 * mu + 0.5
# return signal.astype(np.int32)
# def mulaw_decode(self, wav, qc):
# """Recovers waveform from quantized values."""
# mu = qc - 1
# # Map values back to [-1, 1].
# casted = wav.astype(np.float32)
# signal = 2 * (casted / mu) - 1
# # Perform inverse of mu-law transformation.
# magnitude = (1 / mu) * ((1 + mu) ** abs(signal) - 1)
# return np.sign(signal) * magnitude
def load_wav(self, filename, encode=False):
x, sr = librosa.load(filename, sr=self.sample_rate)
assert self.sample_rate == sr
return x
def encode_16bits(self, x):
return np.clip(x * 2**15, -2**15, 2**15 - 1).astype(np.int16)
def quantize(self, x):
return (x + 1.) * (2**self.bits - 1) / 2
def dequantize(self, x):
return 2 * x / (2**self.bits - 1) - 1
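Note: quantize and dequantize are exact inverses over [-1, 1] up to float error; a quick check, with bits=9 assumed purely for illustration:

    import numpy as np

    bits = 9
    x = np.linspace(-1.0, 1.0, 5)
    q = (x + 1.) * (2**bits - 1) / 2                   # quantize
    assert np.allclose(2 * q / (2**bits - 1) - 1, x)   # dequantize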

View File

@ -1,4 +1,5 @@
import os import os
import re
import sys import sys
import glob import glob
import time import time
@ -21,18 +22,23 @@ class AttrDict(dict):
def load_config(config_path): def load_config(config_path):
config = AttrDict() config = AttrDict()
config.update(json.load(open(config_path, "r"))) with open(config_path, "r") as f:
input_str = f.read()
input_str = re.sub(r'\\\n', '', input_str)
input_str = re.sub(r'//.*\n', '\n', input_str)
data = json.loads(input_str)
config.update(data)
return config return config
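Note: the two substitutions above strip escaped line breaks and // comments, so the annotated config.json files in this commit still parse as valid JSON; a minimal usage sketch (path assumed):

    c = load_config('config.json')
    print(c.audio['sample_rate'])   # nested blocks load as plain dicts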
def get_commit_hash(): def get_commit_hash():
"""https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script""" """https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script"""
try: # try:
subprocess.check_output(['git', 'diff-index', '--quiet', # subprocess.check_output(['git', 'diff-index', '--quiet',
'HEAD']) # Verify client is clean # 'HEAD']) # Verify client is clean
except: # except:
raise RuntimeError( # raise RuntimeError(
" !! Commit before training to get the commit hash.") # " !! Commit before training to get the commit hash.")
commit = subprocess.check_output(['git', 'rev-parse', '--short', commit = subprocess.check_output(['git', 'rev-parse', '--short',
'HEAD']).decode().strip() 'HEAD']).decode().strip()
print(' > Git Hash: {}'.format(commit)) print(' > Git Hash: {}'.format(commit))
@ -177,15 +183,3 @@ def sequence_mask(sequence_length, max_len=None):
seq_length_expand = (sequence_length.unsqueeze(1) seq_length_expand = (sequence_length.unsqueeze(1)
.expand_as(seq_range_expand)) .expand_as(seq_range_expand))
return seq_range_expand < seq_length_expand return seq_range_expand < seq_length_expand
def synthesis(model, ap, text, use_cuda, text_cleaner):
text_cleaner = [text_cleaner]
seq = np.array(text_to_sequence(text, text_cleaner))
chars_var = torch.from_numpy(seq).unsqueeze(0)
if use_cuda:
chars_var = chars_var.cuda().long()
_, linear_out, alignments, _ = model.forward(chars_var)
linear_out = linear_out[0].data.cpu().numpy()
wav = ap.inv_spectrogram(linear_out.T)
return wav, linear_out, alignments

utils/synthesis.py (new file, 23 lines)
View File

@ -0,0 +1,23 @@
import io
import time
import librosa
import torch
import numpy as np
from .text import text_to_sequence
from .visual import visualize
from matplotlib import pylab as plt
def synthesis(m, s, CONFIG, use_cuda, ap):
""" Given the text, synthesising the audio """
text_cleaner = [CONFIG.text_cleaner]
seq = np.array(text_to_sequence(s, text_cleaner))
chars_var = torch.from_numpy(seq).unsqueeze(0)
if use_cuda:
chars_var = chars_var.cuda()
mel_spec, linear_spec, alignments, stop_tokens = m.forward(chars_var.long())
linear_spec = linear_spec[0].data.cpu().numpy()
alignment = alignments[0].cpu().data.numpy()
wav = ap.inv_spectrogram(linear_spec.T)
# wav = wav[:ap.find_endpoint(wav)]
return wav, alignment, linear_spec, stop_tokens
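Note: a hypothetical call of the new helper; the model, config, and audio-processor names are assumed from the training script above:

    wav, alignment, linear_spec, stop_tokens = synthesis(
        model, "Hello world.", c, use_cuda, ap)
    ap.save_wav(wav, 'example.wav')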

View File

@ -1,4 +1,5 @@
import numpy as np import numpy as np
import librosa.display
import matplotlib import matplotlib
matplotlib.use('Agg') matplotlib.use('Agg')
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -14,6 +15,7 @@ def plot_alignment(alignment, info=None):
xlabel += '\n\n' + info xlabel += '\n\n' + info
plt.xlabel(xlabel) plt.xlabel(xlabel)
plt.ylabel('Encoder timestep') plt.ylabel('Encoder timestep')
# plt.yticks(range(len(text)), list(text))
plt.tight_layout() plt.tight_layout()
return fig return fig
@ -25,3 +27,28 @@ def plot_spectrogram(linear_output, audio):
plt.colorbar() plt.colorbar()
plt.tight_layout() plt.tight_layout()
return fig return fig
def visualize(alignment, spectrogram, stop_tokens, text, hop_length, CONFIG):
label_fontsize = 16
plt.figure(figsize=(16, 32))
plt.subplot(3, 1, 1)
plt.imshow(alignment.T, aspect="auto", origin="lower", interpolation=None)
plt.xlabel("Decoder timestamp", fontsize=label_fontsize)
plt.ylabel("Encoder timestamp", fontsize=label_fontsize)
plt.yticks(range(len(text)), list(text))
plt.colorbar()
stop_tokens = stop_tokens.squeeze().detach().to('cpu').numpy()
plt.subplot(3, 1, 2)
plt.plot(range(len(stop_tokens)), list(stop_tokens))
plt.subplot(3, 1, 3)
librosa.display.specshow(spectrogram.T, sr=CONFIG.audio['sample_rate'],
hop_length=hop_length, x_axis="time", y_axis="linear")
plt.xlabel("Time", fontsize=label_fontsize)
plt.ylabel("Hz", fontsize=label_fontsize)
plt.tight_layout()
plt.colorbar()
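Note: a hypothetical end-to-end wiring of visualize with the synthesis helper added in this commit; all variable names are assumed:

    wav, alignment, linear_spec, stop_tokens = synthesis(
        model, sentence, c, use_cuda, ap)
    visualize(alignment, linear_spec, stop_tokens, sentence,
              ap.hop_length, c)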