pep8 format all

Eren G 2018-08-02 16:34:17 +02:00
parent 6ebbae9534
commit 2721578c52
32 changed files with 766 additions and 599 deletions
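The diff below is a pure formatting pass. The commit does not record which tool produced it, so, as a hedged illustration only, here is one way such a PEP 8 pass can be applied programmatically with autopep8 (the tool choice and options are assumptions, not necessarily what this commit used):

```python
# Illustration only: applying a PEP 8 formatting pass with autopep8.
# The actual formatter and settings used for this commit are not recorded here.
import autopep8

source = "def f( x ):\n  return x+1\n"
formatted = autopep8.fix_code(source, options={"aggressive": 1})
print(formatted)
# Prints:
# def f(x):
#     return x + 1
```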

View File

@ -8,21 +8,28 @@ import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
from utils.text import text_to_sequence from utils.text import text_to_sequence
from utils.data import (prepare_data, pad_per_step, from utils.data import (prepare_data, pad_per_step, prepare_tensor,
prepare_tensor, prepare_stop_target) prepare_stop_target)
class MyDataset(Dataset): class MyDataset(Dataset):
def __init__(self,
def __init__(self, root_dir, csv_file, outputs_per_step, root_dir,
text_cleaner, ap, min_seq_len=0): csv_file,
outputs_per_step,
text_cleaner,
ap,
min_seq_len=0):
self.root_dir = root_dir self.root_dir = root_dir
self.wav_dir = os.path.join(root_dir, 'wav') self.wav_dir = os.path.join(root_dir, 'wav')
self.wav_files = glob.glob(os.path.join(self.wav_dir, '*.wav')) self.wav_files = glob.glob(os.path.join(self.wav_dir, '*.wav'))
self._create_file_dict() self._create_file_dict()
self.csv_dir = os.path.join(root_dir, csv_file) self.csv_dir = os.path.join(root_dir, csv_file)
with open(self.csv_dir, "r", encoding="utf8") as f: with open(self.csv_dir, "r", encoding="utf8") as f:
self.frames = [line.split('\t') for line in f if line.split('\t')[0] in self.wav_files_dict.keys()] self.frames = [
line.split('\t') for line in f
if line.split('\t')[0] in self.wav_files_dict.keys()
]
self.outputs_per_step = outputs_per_step self.outputs_per_step = outputs_per_step
self.sample_rate = ap.sample_rate self.sample_rate = ap.sample_rate
self.cleaners = text_cleaner self.cleaners = text_cleaner
@ -43,10 +50,8 @@ class MyDataset(Dataset):
print(" !! Cannot read file : {}".format(filename)) print(" !! Cannot read file : {}".format(filename))
def _trim_silence(self, wav): def _trim_silence(self, wav):
return librosa.effects.trim( return librosa.effects.trim(
wav, top_db=40, wav, top_db=40, frame_length=1024, hop_length=256)[0]
frame_length=1024,
hop_length=256)[0]
def _create_file_dict(self): def _create_file_dict(self):
self.wav_files_dict = {} self.wav_files_dict = {}
@ -87,11 +92,10 @@ class MyDataset(Dataset):
sidx = self.frames[idx][0] sidx = self.frames[idx][0]
sidx_files = self.wav_files_dict[sidx] sidx_files = self.wav_files_dict[sidx]
file_name = random.choice(sidx_files) file_name = random.choice(sidx_files)
wav_name = os.path.join(self.wav_dir, wav_name = os.path.join(self.wav_dir, file_name)
file_name)
text = self.frames[idx][2] text = self.frames[idx][2]
text = np.asarray(text_to_sequence( text = np.asarray(
text, [self.cleaners]), dtype=np.int32) text_to_sequence(text, [self.cleaners]), dtype=np.int32)
wav = np.asarray(self.load_wav(wav_name), dtype=np.float32) wav = np.asarray(self.load_wav(wav_name), dtype=np.float32)
sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]} sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
return sample return sample
@ -121,12 +125,13 @@ class MyDataset(Dataset):
mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame
# compute 'stop token' targets # compute 'stop token' targets
stop_targets = [np.array([0.]*(mel_len-1)) stop_targets = [
for mel_len in mel_lengths] np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
]
# PAD stop targets # PAD stop targets
stop_targets = prepare_stop_target( stop_targets = prepare_stop_target(stop_targets,
stop_targets, self.outputs_per_step) self.outputs_per_step)
# PAD sequences with largest length of the batch # PAD sequences with largest length of the batch
text = prepare_data(text).astype(np.int32) text = prepare_data(text).astype(np.int32)
@ -150,8 +155,8 @@ class MyDataset(Dataset):
mel_lengths = torch.LongTensor(mel_lengths) mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets) stop_targets = torch.FloatTensor(stop_targets)
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0] return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
0]
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\ raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
found {}" found {}".format(type(batch[0]))))
.format(type(batch[0]))))

View File

@ -6,14 +6,18 @@ import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
from utils.text import text_to_sequence from utils.text import text_to_sequence
from utils.data import (prepare_data, pad_per_step, from utils.data import (prepare_data, pad_per_step, prepare_tensor,
prepare_tensor, prepare_stop_target) prepare_stop_target)
class MyDataset(Dataset): class MyDataset(Dataset):
def __init__(self,
def __init__(self, root_dir, csv_file, outputs_per_step, root_dir,
text_cleaner, ap, min_seq_len=0): csv_file,
outputs_per_step,
text_cleaner,
ap,
min_seq_len=0):
self.root_dir = root_dir self.root_dir = root_dir
self.wav_dir = os.path.join(root_dir, 'wavs') self.wav_dir = os.path.join(root_dir, 'wavs')
self.csv_dir = os.path.join(root_dir, csv_file) self.csv_dir = os.path.join(root_dir, csv_file)
@ -60,11 +64,10 @@ class MyDataset(Dataset):
return len(self.frames) return len(self.frames)
def __getitem__(self, idx): def __getitem__(self, idx):
wav_name = os.path.join(self.wav_dir, wav_name = os.path.join(self.wav_dir, self.frames[idx][0]) + '.wav'
self.frames[idx][0]) + '.wav'
text = self.frames[idx][1] text = self.frames[idx][1]
text = np.asarray(text_to_sequence( text = np.asarray(
text, [self.cleaners]), dtype=np.int32) text_to_sequence(text, [self.cleaners]), dtype=np.int32)
wav = np.asarray(self.load_wav(wav_name), dtype=np.float32) wav = np.asarray(self.load_wav(wav_name), dtype=np.float32)
sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]} sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
return sample return sample
@ -94,12 +97,13 @@ class MyDataset(Dataset):
mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame
# compute 'stop token' targets # compute 'stop token' targets
stop_targets = [np.array([0.]*(mel_len-1)) stop_targets = [
for mel_len in mel_lengths] np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
]
# PAD stop targets # PAD stop targets
stop_targets = prepare_stop_target( stop_targets = prepare_stop_target(stop_targets,
stop_targets, self.outputs_per_step) self.outputs_per_step)
# PAD sequences with largest length of the batch # PAD sequences with largest length of the batch
text = prepare_data(text).astype(np.int32) text = prepare_data(text).astype(np.int32)
@ -123,8 +127,8 @@ class MyDataset(Dataset):
mel_lengths = torch.LongTensor(mel_lengths) mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets) stop_targets = torch.FloatTensor(stop_targets)
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0] return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
0]
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\ raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
found {}" found {}".format(type(batch[0]))))
.format(type(batch[0]))))

View File

@ -6,14 +6,18 @@ import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
from utils.text import text_to_sequence from utils.text import text_to_sequence
from utils.data import (prepare_data, pad_per_step, from utils.data import (prepare_data, pad_per_step, prepare_tensor,
prepare_tensor, prepare_stop_target) prepare_stop_target)
class MyDataset(Dataset): class MyDataset(Dataset):
def __init__(self,
def __init__(self, root_dir, csv_file, outputs_per_step, root_dir,
text_cleaner, ap, min_seq_len=0): csv_file,
outputs_per_step,
text_cleaner,
ap,
min_seq_len=0):
self.root_dir = root_dir self.root_dir = root_dir
self.wav_dir = os.path.join(root_dir, 'wavs') self.wav_dir = os.path.join(root_dir, 'wavs')
self.feat_dir = os.path.join(root_dir, 'loader_data') self.feat_dir = os.path.join(root_dir, 'loader_data')
@ -66,20 +70,24 @@ class MyDataset(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
if self.items[idx] is None: if self.items[idx] is None:
wav_name = os.path.join(self.wav_dir, wav_name = os.path.join(self.wav_dir, self.frames[idx][0]) + '.wav'
self.frames[idx][0]) + '.wav'
mel_name = os.path.join(self.feat_dir, mel_name = os.path.join(self.feat_dir,
self.frames[idx][0]) + '.mel.npy' self.frames[idx][0]) + '.mel.npy'
linear_name = os.path.join(self.feat_dir, linear_name = os.path.join(self.feat_dir,
self.frames[idx][0]) + '.linear.npy' self.frames[idx][0]) + '.linear.npy'
text = self.frames[idx][1] text = self.frames[idx][1]
text = np.asarray(text_to_sequence( text = np.asarray(
text, [self.cleaners]), dtype=np.int32) text_to_sequence(text, [self.cleaners]), dtype=np.int32)
wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32) wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
mel = self.load_np(mel_name) mel = self.load_np(mel_name)
linear = self.load_np(linear_name) linear = self.load_np(linear_name)
sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0], sample = {
'mel':mel, 'linear': linear} 'text': text,
'wav': wav,
'item_idx': self.frames[idx][0],
'mel': mel,
'linear': linear
}
self.items[idx] = sample self.items[idx] = sample
else: else:
sample = self.items[idx] sample = self.items[idx]
@ -109,12 +117,13 @@ class MyDataset(Dataset):
mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame
# compute 'stop token' targets # compute 'stop token' targets
stop_targets = [np.array([0.]*(mel_len-1)) stop_targets = [
for mel_len in mel_lengths] np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
]
# PAD stop targets # PAD stop targets
stop_targets = prepare_stop_target( stop_targets = prepare_stop_target(stop_targets,
stop_targets, self.outputs_per_step) self.outputs_per_step)
# PAD sequences with largest length of the batch # PAD sequences with largest length of the batch
text = prepare_data(text).astype(np.int32) text = prepare_data(text).astype(np.int32)
@ -138,8 +147,8 @@ class MyDataset(Dataset):
mel_lengths = torch.LongTensor(mel_lengths) mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets) stop_targets = torch.FloatTensor(stop_targets)
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0] return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
0]
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\ raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
found {}" found {}".format(type(batch[0]))))
.format(type(batch[0]))))

View File

@ -7,15 +7,25 @@ from torch.utils.data import Dataset
from TTS.utils.text import text_to_sequence from TTS.utils.text import text_to_sequence
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.data import (prepare_data, pad_per_step, from TTS.utils.data import (prepare_data, pad_per_step, prepare_tensor,
prepare_tensor, prepare_stop_target) prepare_stop_target)
class TWEBDataset(Dataset): class TWEBDataset(Dataset):
def __init__(self,
def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate, csv_file,
text_cleaner, num_mels, min_level_db, frame_shift_ms, root_dir,
frame_length_ms, preemphasis, ref_level_db, num_freq, power, outputs_per_step,
sample_rate,
text_cleaner,
num_mels,
min_level_db,
frame_shift_ms,
frame_length_ms,
preemphasis,
ref_level_db,
num_freq,
power,
min_seq_len=0): min_seq_len=0):
with open(csv_file, "r") as f: with open(csv_file, "r") as f:
@ -25,8 +35,9 @@ class TWEBDataset(Dataset):
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.cleaners = text_cleaner self.cleaners = text_cleaner
self.min_seq_len = min_seq_len self.min_seq_len = min_seq_len
self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms, self.ap = AudioProcessor(sample_rate, num_mels, min_level_db,
frame_length_ms, preemphasis, ref_level_db, num_freq, power) frame_shift_ms, frame_length_ms, preemphasis,
ref_level_db, num_freq, power)
print(" > Reading TWEB from - {}".format(root_dir)) print(" > Reading TWEB from - {}".format(root_dir))
print(" | > Number of instances : {}".format(len(self.frames))) print(" | > Number of instances : {}".format(len(self.frames)))
self._sort_frames() self._sort_frames()
@ -63,11 +74,10 @@ class TWEBDataset(Dataset):
return len(self.frames) return len(self.frames)
def __getitem__(self, idx): def __getitem__(self, idx):
wav_name = os.path.join(self.root_dir, wav_name = os.path.join(self.root_dir, self.frames[idx][0]) + '.wav'
self.frames[idx][0]) + '.wav'
text = self.frames[idx][1] text = self.frames[idx][1]
text = np.asarray(text_to_sequence( text = np.asarray(
text, [self.cleaners]), dtype=np.int32) text_to_sequence(text, [self.cleaners]), dtype=np.int32)
wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32) wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]} sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
return sample return sample
@ -97,12 +107,13 @@ class TWEBDataset(Dataset):
mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame
# compute 'stop token' targets # compute 'stop token' targets
stop_targets = [np.array([0.]*(mel_len-1)) stop_targets = [
for mel_len in mel_lengths] np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
]
# PAD stop targets # PAD stop targets
stop_targets = prepare_stop_target( stop_targets = prepare_stop_target(stop_targets,
stop_targets, self.outputs_per_step) self.outputs_per_step)
# PAD sequences with largest length of the batch # PAD sequences with largest length of the batch
text = prepare_data(text).astype(np.int32) text = prepare_data(text).astype(np.int32)
@ -126,8 +137,8 @@ class TWEBDataset(Dataset):
mel_lengths = torch.LongTensor(mel_lengths) mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets) stop_targets = torch.FloatTensor(stop_targets)
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0] return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
0]
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\ raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
found {}" found {}".format(type(batch[0]))))
.format(type(batch[0]))))

View File

@@ -10,7 +10,6 @@
     "hidden_size": 128,
    "embedding_size": 256,
    "text_cleaner": "english_cleaners",
    "epochs": 200,
    "lr": 0.01,
    "lr_patience": 2,
@@ -19,9 +18,7 @@
     "griffinf_lim_iters": 60,
    "power": 1.5,
    "r": 5,
    "num_loader_workers": 16,
    "save_step": 1,
    "data_path": "/data/shared/KeithIto/LJSpeech-1.0",
    "output_path": "result",

View File

@ -13,19 +13,19 @@ from utils.generic_utils import load_config
from multiprocessing import Pool from multiprocessing import Pool
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, parser.add_argument('--data_path', type=str, help='Data folder.')
help='Data folder.') parser.add_argument('--out_path', type=str, help='Output folder.')
parser.add_argument('--out_path', type=str, parser.add_argument(
help='Output folder.') '--config', type=str, help='conf.json file for run settings.')
parser.add_argument('--config', type=str, parser.add_argument(
help='conf.json file for run settings.') "--num_proc", type=int, default=8, help="number of processes.")
parser.add_argument("--num_proc", type=int, default=8, parser.add_argument(
help="number of processes.") "--trim_silence",
parser.add_argument("--trim_silence", type=bool, default=False, type=bool,
help="trim silence in the voice clip.") default=False,
help="trim silence in the voice clip.")
args = parser.parse_args() args = parser.parse_args()
DATA_PATH = args.data_path DATA_PATH = args.data_path
OUT_PATH = args.out_path OUT_PATH = args.out_path
@ -34,27 +34,26 @@ if __name__ == "__main__":
print(" > Input path: ", DATA_PATH) print(" > Input path: ", DATA_PATH)
print(" > Output path: ", OUT_PATH) print(" > Output path: ", OUT_PATH)
audio = importlib.import_module('utils.'+c.audio_processor) audio = importlib.import_module('utils.' + c.audio_processor)
AudioProcessor = getattr(audio, 'AudioProcessor') AudioProcessor = getattr(audio, 'AudioProcessor')
ap = AudioProcessor(sample_rate = CONFIG.sample_rate, ap = AudioProcessor(
num_mels = CONFIG.num_mels, sample_rate=CONFIG.sample_rate,
min_level_db = CONFIG.min_level_db, num_mels=CONFIG.num_mels,
frame_shift_ms = CONFIG.frame_shift_ms, min_level_db=CONFIG.min_level_db,
frame_length_ms = CONFIG.frame_length_ms, frame_shift_ms=CONFIG.frame_shift_ms,
ref_level_db = CONFIG.ref_level_db, frame_length_ms=CONFIG.frame_length_ms,
num_freq = CONFIG.num_freq, ref_level_db=CONFIG.ref_level_db,
power = CONFIG.power, num_freq=CONFIG.num_freq,
preemphasis = CONFIG.preemphasis, power=CONFIG.power,
min_mel_freq = CONFIG.min_mel_freq, preemphasis=CONFIG.preemphasis,
max_mel_freq = CONFIG.max_mel_freq) min_mel_freq=CONFIG.min_mel_freq,
max_mel_freq=CONFIG.max_mel_freq)
def trim_silence(self, wav): def trim_silence(self, wav):
margin = int(CONFIG.sample_rate * 0.1) margin = int(CONFIG.sample_rate * 0.1)
wav = wav[margin:-margin] wav = wav[margin:-margin]
return librosa.effects.trim( return librosa.effects.trim(
wav, top_db=40, wav, top_db=40, frame_length=1024, hop_length=256)[0]
frame_length=1024,
hop_length=256)[0]
def extract_mel(file_path): def extract_mel(file_path):
# x, fs = sf.read(file_path) # x, fs = sf.read(file_path)
@ -63,23 +62,25 @@ if __name__ == "__main__":
x = trim_silence(x) x = trim_silence(x)
mel = ap.melspectrogram(x.astype('float32')).astype('float32') mel = ap.melspectrogram(x.astype('float32')).astype('float32')
linear = ap.spectrogram(x.astype('float32')).astype('float32') linear = ap.spectrogram(x.astype('float32')).astype('float32')
file_name = os.path.basename(file_path).replace(".wav","") file_name = os.path.basename(file_path).replace(".wav", "")
mel_file = file_name + ".mel" mel_file = file_name + ".mel"
linear_file = file_name + ".linear" linear_file = file_name + ".linear"
np.save(os.path.join(OUT_PATH, mel_file), mel, allow_pickle=False) np.save(os.path.join(OUT_PATH, mel_file), mel, allow_pickle=False)
np.save(os.path.join(OUT_PATH, linear_file), linear, allow_pickle=False) np.save(
os.path.join(OUT_PATH, linear_file), linear, allow_pickle=False)
mel_len = mel.shape[1] mel_len = mel.shape[1]
linear_len = linear.shape[1] linear_len = linear.shape[1]
wav_len = x.shape[0] wav_len = x.shape[0]
print(" > " + file_path, flush=True) print(" > " + file_path, flush=True)
return file_path, mel_file, linear_file, str(wav_len), str(mel_len), str(linear_len) return file_path, mel_file, linear_file, str(wav_len), str(
mel_len), str(linear_len)
glob_path = os.path.join(DATA_PATH, "*.wav") glob_path = os.path.join(DATA_PATH, "*.wav")
print(" > Reading wav: {}".format(glob_path)) print(" > Reading wav: {}".format(glob_path))
file_names = glob.glob(glob_path, recursive=True) file_names = glob.glob(glob_path, recursive=True)
if __name__ == "__main__": if __name__ == "__main__":
print(" > Number of files: %i"%(len(file_names))) print(" > Number of files: %i" % (len(file_names)))
if not os.path.exists(OUT_PATH): if not os.path.exists(OUT_PATH):
os.makedirs(OUT_PATH) os.makedirs(OUT_PATH)
print(" > A new folder created at {}".format(OUT_PATH)) print(" > A new folder created at {}".format(OUT_PATH))
@ -88,7 +89,10 @@ if __name__ == "__main__":
if args.num_proc > 1: if args.num_proc > 1:
print(" > Using {} processes.".format(args.num_proc)) print(" > Using {} processes.".format(args.num_proc))
with Pool(args.num_proc) as p: with Pool(args.num_proc) as p:
r = list(tqdm.tqdm(p.imap(extract_mel, file_names), total=len(file_names))) r = list(
tqdm.tqdm(
p.imap(extract_mel, file_names),
total=len(file_names)))
# r = list(p.imap(extract_mel, file_names)) # r = list(p.imap(extract_mel, file_names))
else: else:
print(" > Using single process run.") print(" > Using single process run.")
@ -100,5 +104,5 @@ if __name__ == "__main__":
file = open(file_path, "w") file = open(file_path, "w")
for line in r: for line in r:
line = ", ".join(line) line = ", ".join(line)
file.write(line+'\n') file.write(line + '\n')
file.close() file.close()

View File

@ -24,8 +24,8 @@ class BahdanauAttention(nn.Module):
processed_query = self.query_layer(query) processed_query = self.query_layer(query)
processed_annots = self.annot_layer(annots) processed_annots = self.annot_layer(annots)
# (batch, max_time, 1) # (batch, max_time, 1)
alignment = self.v(nn.functional.tanh( alignment = self.v(
processed_query + processed_annots)) nn.functional.tanh(processed_query + processed_annots))
# (batch, max_time) # (batch, max_time)
return alignment.squeeze(-1) return alignment.squeeze(-1)
@ -33,15 +33,24 @@ class BahdanauAttention(nn.Module):
class LocationSensitiveAttention(nn.Module): class LocationSensitiveAttention(nn.Module):
"""Location sensitive attention following """Location sensitive attention following
https://arxiv.org/pdf/1506.07503.pdf""" https://arxiv.org/pdf/1506.07503.pdf"""
def __init__(self, annot_dim, query_dim, attn_dim,
kernel_size=7, filters=20): def __init__(self,
annot_dim,
query_dim,
attn_dim,
kernel_size=7,
filters=20):
super(LocationSensitiveAttention, self).__init__() super(LocationSensitiveAttention, self).__init__()
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.filters = filters self.filters = filters
padding = int((kernel_size - 1) / 2) padding = int((kernel_size - 1) / 2)
self.loc_conv = nn.Conv1d(2, filters, self.loc_conv = nn.Conv1d(
kernel_size=kernel_size, stride=1, 2,
padding=padding, bias=False) filters,
kernel_size=kernel_size,
stride=1,
padding=padding,
bias=False)
self.loc_linear = nn.Linear(filters, attn_dim) self.loc_linear = nn.Linear(filters, attn_dim)
self.query_layer = nn.Linear(query_dim, attn_dim, bias=True) self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True) self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
@ -62,8 +71,9 @@ class LocationSensitiveAttention(nn.Module):
processed_loc = self.loc_linear(loc_conv) processed_loc = self.loc_linear(loc_conv)
processed_query = self.query_layer(query) processed_query = self.query_layer(query)
processed_annots = self.annot_layer(annot) processed_annots = self.annot_layer(annot)
alignment = self.v(nn.functional.tanh( alignment = self.v(
processed_query + processed_annots + processed_loc)) nn.functional.tanh(processed_query + processed_annots +
processed_loc))
# (batch, max_time) # (batch, max_time)
return alignment.squeeze(-1) return alignment.squeeze(-1)
@ -85,16 +95,23 @@ class AttentionRNNCell(nn.Module):
self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, rnn_dim) self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, rnn_dim)
# pick bahdanau or location sensitive attention # pick bahdanau or location sensitive attention
if align_model == 'b': if align_model == 'b':
self.alignment_model = BahdanauAttention(annot_dim, rnn_dim, out_dim) self.alignment_model = BahdanauAttention(annot_dim, rnn_dim,
out_dim)
if align_model == 'ls': if align_model == 'ls':
self.alignment_model = LocationSensitiveAttention(annot_dim, rnn_dim, out_dim) self.alignment_model = LocationSensitiveAttention(
annot_dim, rnn_dim, out_dim)
else: else:
raise RuntimeError(" Wrong alignment model name: {}. Use\ raise RuntimeError(" Wrong alignment model name: {}. Use\
'b' (Bahdanau) or 'ls' (Location Sensitive).".format(align_model)) 'b' (Bahdanau) or 'ls' (Location Sensitive)."
.format(align_model))
def forward(self,
def forward(self, memory, context, rnn_state, annots, memory,
atten, annot_lens=None): context,
rnn_state,
annots,
atten,
annot_lens=None):
""" """
Shapes: Shapes:
- memory: (batch, 1, dim) or (batch, dim) - memory: (batch, 1, dim) or (batch, dim)

View File

@@ -2,7 +2,6 @@
 import torch
 from torch import nn
 # class StopProjection(nn.Module):
 # r""" Simple projection layer to predict the "stop token"

View File

@@ -5,7 +5,6 @@ from utils.generic_utils import sequence_mask
 class L1LossMasked(nn.Module):
     def __init__(self):
         super(L1LossMasked, self).__init__()
@@ -31,21 +30,20 @@ class L1LossMasked(nn.Module):
         # target_flat: (batch * max_len, dim)
         target_flat = target.view(-1, target.shape[-1])
         # losses_flat: (batch * max_len, dim)
-        losses_flat = functional.l1_loss(input, target_flat, size_average=False,
-                                         reduce=False)
+        losses_flat = functional.l1_loss(
+            input, target_flat, size_average=False, reduce=False)
         # losses: (batch, max_len, dim)
         losses = losses_flat.view(*target.size())
         # mask: (batch, max_len, 1)
-        mask = sequence_mask(sequence_length=length,
-                             max_len=target.size(1)).unsqueeze(2)
+        mask = sequence_mask(
+            sequence_length=length, max_len=target.size(1)).unsqueeze(2)
         losses = losses * mask.float()
         loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
         return loss
 class MSELossMasked(nn.Module):
     def __init__(self):
         super(MSELossMasked, self).__init__()
@@ -71,14 +69,14 @@ class MSELossMasked(nn.Module):
         # target_flat: (batch * max_len, dim)
         target_flat = target.view(-1, target.shape[-1])
         # losses_flat: (batch * max_len, dim)
-        losses_flat = functional.mse_loss(input, target_flat, size_average=False,
-                                          reduce=False)
+        losses_flat = functional.mse_loss(
+            input, target_flat, size_average=False, reduce=False)
         # losses: (batch, max_len, dim)
         losses = losses_flat.view(*target.size())
         # mask: (batch, max_len, 1)
-        mask = sequence_mask(sequence_length=length,
-                             max_len=target.size(1)).unsqueeze(2)
+        mask = sequence_mask(
+            sequence_length=length, max_len=target.size(1)).unsqueeze(2)
         losses = losses * mask.float()
         loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
         return loss

View File

@ -3,6 +3,7 @@ import torch
from torch import nn from torch import nn
from .attention import AttentionRNNCell from .attention import AttentionRNNCell
class Prenet(nn.Module): class Prenet(nn.Module):
r""" Prenet as explained at https://arxiv.org/abs/1703.10135. r""" Prenet as explained at https://arxiv.org/abs/1703.10135.
It creates as many layers as given by 'out_features' It creates as many layers as given by 'out_features'
@ -16,9 +17,10 @@ class Prenet(nn.Module):
def __init__(self, in_features, out_features=[256, 128]): def __init__(self, in_features, out_features=[256, 128]):
super(Prenet, self).__init__() super(Prenet, self).__init__()
in_features = [in_features] + out_features[:-1] in_features = [in_features] + out_features[:-1]
self.layers = nn.ModuleList( self.layers = nn.ModuleList([
[nn.Linear(in_size, out_size) nn.Linear(in_size, out_size)
for (in_size, out_size) in zip(in_features, out_features)]) for (in_size, out_size) in zip(in_features, out_features)
])
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.5) self.dropout = nn.Dropout(0.5)
@ -46,12 +48,21 @@ class BatchNormConv1d(nn.Module):
- output: batch x dims - output: batch x dims
""" """
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
activation=None): activation=None):
super(BatchNormConv1d, self).__init__() super(BatchNormConv1d, self).__init__()
self.conv1d = nn.Conv1d(in_channels, out_channels, self.conv1d = nn.Conv1d(
kernel_size=kernel_size, in_channels,
stride=stride, padding=padding, bias=False) out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=False)
# Following tensorflow's default parameters # Following tensorflow's default parameters
self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3) self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3)
self.activation = activation self.activation = activation
@ -96,16 +107,25 @@ class CBHG(nn.Module):
- output: batch x time x dim*2 - output: batch x time x dim*2
""" """
def __init__(self, in_features, K=16, projections=[128, 128], num_highways=4): def __init__(self,
in_features,
K=16,
projections=[128, 128],
num_highways=4):
super(CBHG, self).__init__() super(CBHG, self).__init__()
self.in_features = in_features self.in_features = in_features
self.relu = nn.ReLU() self.relu = nn.ReLU()
# list of conv1d bank with filter size k=1...K # list of conv1d bank with filter size k=1...K
# TODO: try dilational layers instead # TODO: try dilational layers instead
self.conv1d_banks = nn.ModuleList( self.conv1d_banks = nn.ModuleList([
[BatchNormConv1d(in_features, in_features, kernel_size=k, stride=1, BatchNormConv1d(
padding=k // 2, activation=self.relu) in_features,
for k in range(1, K + 1)]) in_features,
kernel_size=k,
stride=1,
padding=k // 2,
activation=self.relu) for k in range(1, K + 1)
])
# max pooling of conv bank # max pooling of conv bank
# TODO: try average pooling OR larger kernel size # TODO: try average pooling OR larger kernel size
self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1) self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
@ -114,9 +134,15 @@ class CBHG(nn.Module):
activations += [None] activations += [None]
# setup conv1d projection layers # setup conv1d projection layers
layer_set = [] layer_set = []
for (in_size, out_size, ac) in zip(out_features, projections, activations): for (in_size, out_size, ac) in zip(out_features, projections,
layer = BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1, activations):
padding=1, activation=ac) layer = BatchNormConv1d(
in_size,
out_size,
kernel_size=3,
stride=1,
padding=1,
activation=ac)
layer_set.append(layer) layer_set.append(layer)
self.conv1d_projections = nn.ModuleList(layer_set) self.conv1d_projections = nn.ModuleList(layer_set)
# setup Highway layers # setup Highway layers
@ -204,10 +230,14 @@ class Decoder(nn.Module):
# memory -> |Prenet| -> processed_memory # memory -> |Prenet| -> processed_memory
self.prenet = Prenet(memory_dim * r, out_features=[256, 128]) self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
self.attention_rnn = AttentionRNNCell(out_dim=128, rnn_dim=256, annot_dim=in_features, self.attention_rnn = AttentionRNNCell(
memory_dim=128, align_model='ls') out_dim=128,
rnn_dim=256,
annot_dim=in_features,
memory_dim=128,
align_model='ls')
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
self.project_to_decoder_in = nn.Linear(256+in_features, 256) self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
# decoder_RNN_input -> |RNN| -> RNN_state # decoder_RNN_input -> |RNN| -> RNN_state
self.decoder_rnns = nn.ModuleList( self.decoder_rnns = nn.ModuleList(
[nn.GRUCell(256, 256) for _ in range(2)]) [nn.GRUCell(256, 256) for _ in range(2)])
@ -241,17 +271,20 @@ class Decoder(nn.Module):
# Grouping multiple frames if necessary # Grouping multiple frames if necessary
if memory.size(-1) == self.memory_dim: if memory.size(-1) == self.memory_dim:
memory = memory.view(B, memory.size(1) // self.r, -1) memory = memory.view(B, memory.size(1) // self.r, -1)
" !! Dimension mismatch {} vs {} * {}".format(memory.size(-1), " !! Dimension mismatch {} vs {} * {}".format(
self.memory_dim, self.r) memory.size(-1), self.memory_dim, self.r)
T_decoder = memory.size(1) T_decoder = memory.size(1)
# go frame as zeros matrix # go frame as zeros matrix
initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_() initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
# decoder states # decoder states
attention_rnn_hidden = inputs.data.new(B, 256).zero_() attention_rnn_hidden = inputs.data.new(B, 256).zero_()
decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_() decoder_rnn_hiddens = [
for _ in range(len(self.decoder_rnns))] inputs.data.new(B, 256).zero_()
for _ in range(len(self.decoder_rnns))
]
current_context_vec = inputs.data.new(B, self.in_features).zero_() current_context_vec = inputs.data.new(B, self.in_features).zero_()
stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_() stopnet_rnn_hidden = inputs.data.new(B,
self.r * self.memory_dim).zero_()
# attention states # attention states
attention = inputs.data.new(B, T).zero_() attention = inputs.data.new(B, T).zero_()
attention_cum = inputs.data.new(B, T).zero_() attention_cum = inputs.data.new(B, T).zero_()
@ -268,13 +301,12 @@ class Decoder(nn.Module):
if greedy: if greedy:
memory_input = outputs[-1] memory_input = outputs[-1]
else: else:
memory_input = memory[t-1] memory_input = memory[t - 1]
# Prenet # Prenet
processed_memory = self.prenet(memory_input) processed_memory = self.prenet(memory_input)
# Attention RNN # Attention RNN
attention_cat = torch.cat((attention.unsqueeze(1), attention_cat = torch.cat(
attention_cum.unsqueeze(1)), (attention.unsqueeze(1), attention_cum.unsqueeze(1)), dim=1)
dim=1)
attention_rnn_hidden, current_context_vec, attention = self.attention_rnn( attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
processed_memory, current_context_vec, attention_rnn_hidden, processed_memory, current_context_vec, attention_rnn_hidden,
inputs, attention_cat, input_lens) inputs, attention_cat, input_lens)
@ -293,16 +325,18 @@ class Decoder(nn.Module):
output = self.proj_to_mel(decoder_output) output = self.proj_to_mel(decoder_output)
stop_input = output stop_input = output
# predict stop token # predict stop token
stop_token, stopnet_rnn_hidden = self.stopnet(stop_input, stopnet_rnn_hidden) stop_token, stopnet_rnn_hidden = self.stopnet(
stop_input, stopnet_rnn_hidden)
outputs += [output] outputs += [output]
attentions += [attention] attentions += [attention]
stop_tokens += [stop_token] stop_tokens += [stop_token]
t += 1 t += 1
if (not greedy and self.training) or (greedy and memory is not None): if (not greedy and self.training) or (greedy
and memory is not None):
if t >= T_decoder: if t >= T_decoder:
break break
else: else:
if t > inputs.shape[1]/2 and stop_token > 0.6: if t > inputs.shape[1] / 2 and stop_token > 0.6:
break break
elif t > self.max_decoder_steps: elif t > self.max_decoder_steps:
print(" | | > Decoder stopped with 'max_decoder_steps") print(" | | > Decoder stopped with 'max_decoder_steps")

View File

@@ -6,14 +6,18 @@ from layers.tacotron import Prenet, Encoder, Decoder, CBHG
 class Tacotron(nn.Module):
-    def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80,
-                 r=5, padding_idx=None):
+    def __init__(self,
+                 embedding_dim=256,
+                 linear_dim=1025,
+                 mel_dim=80,
+                 r=5,
+                 padding_idx=None):
         super(Tacotron, self).__init__()
         self.r = r
         self.mel_dim = mel_dim
         self.linear_dim = linear_dim
-        self.embedding = nn.Embedding(len(symbols), embedding_dim,
-                                      padding_idx=padding_idx)
+        self.embedding = nn.Embedding(
+            len(symbols), embedding_dim, padding_idx=padding_idx)
         print(" | > Number of characters : {}".format(len(symbols)))
         self.embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(embedding_dim)

View File

@@ -40,8 +40,12 @@ def visualize(alignment, spectrogram, stop_tokens, CONFIG):
     plt.plot(range(len(stop_tokens)), list(stop_tokens))
     plt.subplot(3, 1, 3)
-    librosa.display.specshow(spectrogram.T, sr=CONFIG.sample_rate,
-                             hop_length=hop_length, x_axis="time", y_axis="linear")
+    librosa.display.specshow(
+        spectrogram.T,
+        sr=CONFIG.sample_rate,
+        hop_length=hop_length,
+        x_axis="time",
+        y_axis="linear")
     plt.xlabel("Time", fontsize=label_fontsize)
     plt.ylabel("Hz", fontsize=label_fontsize)
     plt.tight_layout()

View File

@@ -1,9 +1,9 @@
 ## TTS example web-server
 Steps to run:
-1. Download one of the models given on the main page.
-2. Checkout the corresponding commit history.
-2. Set paths and other options in the file ```server/conf.json```.
-3. Run the server ```python server/server.py -c conf.json```. (Requires Flask)
-4. Go to ```localhost:[given_port]``` and enjoy.
+1. Download one of the models given on the main page. Click [here](https://drive.google.com/drive/folders/1Q6BKeEkZyxSGsocK2p_mqgzLwlNvbHFJ?usp=sharing) for the latest model.
+2. Check out the corresponding commit history, or use the ```server``` branch if you would like to use the latest model.
+3. Set the paths and the other options in the file ```server/conf.json```.
+4. Run the server with ```python server/server.py -c server/conf.json```. (Requires Flask)
+5. Go to ```localhost:[given_port]``` and enjoy.
+
+Note that the audio quality in the browser is slightly worse due to the encoder quantization. For high-quality results, please use the library versions listed in the ```requirements.txt``` file.
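As a hedged usage sketch (not part of this commit), the running server can be queried over HTTP once the steps above are done; the port and endpoint come from the server/conf.json and server/server.py files shown below, and the `requests` package is an extra assumption:

```python
# Hedged example: fetch synthesized speech from the example web-server.
# Assumes the server is running locally on port 5002 (see server/conf.json)
# and that the `requests` package is installed.
import requests

response = requests.get(
    "http://localhost:5002/api/tts", params={"text": "Hello world."})
response.raise_for_status()

# /api/tts returns a wav payload via Flask's send_file.
with open("tts_output.wav", "wb") as wav_file:
    wav_file.write(response.content)
```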

View File

@@ -1,5 +1,5 @@
 {
-    "model_path":"/home/erogol/projects/models/LJSpeech/May-22-2018_03_24PM-e6112f7",
+    "model_path":"../models/May-22-2018_03_24PM-e6112f7",
     "model_name":"checkpoint_272976.pth.tar",
     "model_config":"config.json",
     "port": 5002,

View File

@@ -2,12 +2,11 @@
 import argparse
 from synthesizer import Synthesizer
 from TTS.utils.generic_utils import load_config
-from flask import (Flask, Response, request,
-                   render_template, send_file)
+from flask import Flask, Response, request, render_template, send_file
 parser = argparse.ArgumentParser()
-parser.add_argument('-c', '--config_path', type=str,
-                    help='path to config file for training')
+parser.add_argument(
+    '-c', '--config_path', type=str, help='path to config file for training')
 args = parser.parse_args()
 config = load_config(args.config_path)
@@ -16,17 +15,19 @@ synthesizer = Synthesizer()
 synthesizer.load_model(config.model_path, config.model_name,
                        config.model_config, config.use_cuda)
 @app.route('/')
 def index():
     return render_template('index.html')
 @app.route('/api/tts', methods=['GET'])
 def tts():
     text = request.args.get('text')
     print(" > Model input: {}".format(text))
     data = synthesizer.tts(text)
-    return send_file(data,
-                     mimetype='audio/wav')
+    return send_file(data, mimetype='audio/wav')
 if __name__ == '__main__':
     app.run(debug=True, host='0.0.0.0', port=config.port)

View File

@@ -13,7 +13,6 @@ from matplotlib import pylab as plt
 class Synthesizer(object):
     def load_model(self, model_path, model_name, model_config, use_cuda):
         model_config = os.path.join(model_path, model_config)
         self.model_file = os.path.join(model_path, model_name)
@@ -23,15 +22,25 @@ class Synthesizer(object):
         config = load_config(model_config)
         self.config = config
         self.use_cuda = use_cuda
-        self.model = Tacotron(config.embedding_size, config.num_freq, config.num_mels, config.r)
-        self.ap = AudioProcessor(config.sample_rate, config.num_mels, config.min_level_db,
-                                 config.frame_shift_ms, config.frame_length_ms, config.preemphasis,
-                                 config.ref_level_db, config.num_freq, config.power, griffin_lim_iters=60)
+        self.model = Tacotron(config.embedding_size, config.num_freq,
+                              config.num_mels, config.r)
+        self.ap = AudioProcessor(
+            config.sample_rate,
+            config.num_mels,
+            config.min_level_db,
+            config.frame_shift_ms,
+            config.frame_length_ms,
+            config.preemphasis,
+            config.ref_level_db,
+            config.num_freq,
+            config.power,
+            griffin_lim_iters=60)
         # load model state
         if use_cuda:
             cp = torch.load(self.model_file)
         else:
-            cp = torch.load(self.model_file, map_location=lambda storage, loc: storage)
+            cp = torch.load(
+                self.model_file, map_location=lambda storage, loc: storage)
         # load the model
         self.model.load_state_dict(cp['model'])
         if use_cuda:
@@ -40,12 +49,8 @@ class Synthesizer(object):
     def save_wav(self, wav, path):
         wav *= 32767 / max(1e-8, np.max(np.abs(wav)))
-        # sf.write(path, wav.astype(np.int32), self.config.sample_rate, format='wav')
-        # wav = librosa.util.normalize(wav.astype(np.float), norm=np.inf, axis=None)
-        # wav = wav / wav.max()
-        # sf.write(path, wav.astype('float'), self.config.sample_rate, format='ogg')
-        scipy.io.wavfile.write(path, self.config.sample_rate, wav.astype(np.int16))
-        # librosa.output.write_wav(path, wav.astype(np.int16), self.config.sample_rate, norm=True)
+        librosa.output.write_wav(path, wav.astype(np.int16),
+                                 self.config.sample_rate)
     def tts(self, text):
         text_cleaner = [self.config.text_cleaner]
@@ -54,14 +59,15 @@ class Synthesizer(object):
             if len(sen) < 3:
                 continue
             sen = sen.strip()
-            sen +='.'
+            sen += '.'
             print(sen)
             sen = sen.strip()
             seq = np.array(text_to_sequence(text, text_cleaner))
-            chars_var = torch.from_numpy(seq).unsqueeze(0)
+            chars_var = torch.from_numpy(seq).unsqueeze(0).long()
             if self.use_cuda:
                 chars_var = chars_var.cuda()
-            mel_out, linear_out, alignments, stop_tokens = self.model.forward(chars_var)
+            mel_out, linear_out, alignments, stop_tokens = self.model.forward(
+                chars_var)
             linear_out = linear_out[0].data.cpu().numpy()
             wav = self.ap.inv_spectrogram(linear_out.T)
             # wav = wav[:self.ap.find_endpoint(wav)]

View File

@ -25,7 +25,6 @@ else:
class build_py(setuptools.command.build_py.build_py): class build_py(setuptools.command.build_py.build_py):
def run(self): def run(self):
self.create_version_file() self.create_version_file()
setuptools.command.build_py.build_py.run(self) setuptools.command.build_py.build_py.run(self)
@ -40,7 +39,6 @@ class build_py(setuptools.command.build_py.build_py):
class develop(setuptools.command.develop.develop): class develop(setuptools.command.develop.develop):
def run(self): def run(self):
build_py.create_version_file() build_py.create_version_file()
setuptools.command.develop.develop.run(self) setuptools.command.develop.develop.run(self)
@ -50,8 +48,11 @@ def create_readme_rst():
global cwd global cwd
try: try:
subprocess.check_call( subprocess.check_call(
["pandoc", "--from=markdown", "--to=rst", "--output=README.rst", [
"README.md"], cwd=cwd) "pandoc", "--from=markdown", "--to=rst", "--output=README.rst",
"README.md"
],
cwd=cwd)
print("Generated README.rst from README.md using pandoc.") print("Generated README.rst from README.md using pandoc.")
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
pass pass
@ -59,33 +60,31 @@ def create_readme_rst():
pass pass
setup(name='TTS', setup(
version=version, name='TTS',
url='https://github.com/mozilla/TTS', version=version,
description='Text to Speech with Deep Learning', url='https://github.com/mozilla/TTS',
description='Text to Speech with Deep Learning',
packages=find_packages(), packages=find_packages(),
cmdclass={ cmdclass={
'build_py': build_py, 'build_py': build_py,
'develop': develop, 'develop': develop,
}, },
setup_requires=[ setup_requires=["numpy"],
"numpy" install_requires=[
], "scipy",
install_requires=[ "torch == 0.4.0",
"scipy", "librosa",
"torch == 0.4.0", "unidecode",
"librosa", "tensorboardX",
"unidecode", "matplotlib",
"tensorboardX", "Pillow",
"matplotlib", "flask",
"Pillow", "lws",
"flask", ],
"lws", extras_require={
], "bin": [
extras_require={ "tqdm",
"bin": [ "requests",
"tqdm", ],
"requests", })
],
})

View File

@ -8,19 +8,17 @@ OUT_PATH = '/tmp/test.pth.tar'
class ModelSavingTests(unittest.TestCase): class ModelSavingTests(unittest.TestCase):
def save_checkpoint_test(self): def save_checkpoint_test(self):
# create a dummy model # create a dummy model
model = Prenet(128, out_features=[256, 128]) model = Prenet(128, out_features=[256, 128])
model = T.nn.DataParallel(layer) model = T.nn.DataParallel(layer)
# save the model # save the model
save_checkpoint(model, None, 100, save_checkpoint(model, None, 100, OUTPATH, 1, 1)
OUTPATH, 1, 1)
# load the model to CPU # load the model to CPU
model_dict = torch.load(MODEL_PATH, map_location=lambda storage, model_dict = torch.load(
loc: storage) MODEL_PATH, map_location=lambda storage, loc: storage)
model.load_state_dict(model_dict['model']) model.load_state_dict(model_dict['model'])
def save_best_model_test(self): def save_best_model_test(self):
@ -29,11 +27,9 @@ class ModelSavingTests(unittest.TestCase):
model = T.nn.DataParallel(layer) model = T.nn.DataParallel(layer)
# save the model # save the model
best_loss = save_best_model(model, None, 0, best_loss = save_best_model(model, None, 0, 100, OUT_PATH, 10, 1)
100, OUT_PATH,
10, 1)
# load the model to CPU # load the model to CPU
model_dict = torch.load(MODEL_PATH, map_location=lambda storage, model_dict = torch.load(
loc: storage) MODEL_PATH, map_location=lambda storage, loc: storage)
model.load_state_dict(model_dict['model']) model.load_state_dict(model_dict['model'])

View File

@ -7,7 +7,6 @@ from TTS.utils.generic_utils import sequence_mask
class PrenetTests(unittest.TestCase): class PrenetTests(unittest.TestCase):
def test_in_out(self): def test_in_out(self):
layer = Prenet(128, out_features=[256, 128]) layer = Prenet(128, out_features=[256, 128])
dummy_input = T.rand(4, 128) dummy_input = T.rand(4, 128)
@ -19,7 +18,6 @@ class PrenetTests(unittest.TestCase):
class CBHGTests(unittest.TestCase): class CBHGTests(unittest.TestCase):
def test_in_out(self): def test_in_out(self):
layer = CBHG(128, K=6, projections=[128, 128], num_highways=2) layer = CBHG(128, K=6, projections=[128, 128], num_highways=2)
dummy_input = T.rand(4, 8, 128) dummy_input = T.rand(4, 8, 128)
@ -32,7 +30,6 @@ class CBHGTests(unittest.TestCase):
class DecoderTests(unittest.TestCase): class DecoderTests(unittest.TestCase):
def test_in_out(self): def test_in_out(self):
layer = Decoder(in_features=256, memory_dim=80, r=2) layer = Decoder(in_features=256, memory_dim=80, r=2)
dummy_input = T.rand(4, 8, 256) dummy_input = T.rand(4, 8, 256)
@ -49,7 +46,6 @@ class DecoderTests(unittest.TestCase):
class EncoderTests(unittest.TestCase): class EncoderTests(unittest.TestCase):
def test_in_out(self): def test_in_out(self):
layer = Encoder(128) layer = Encoder(128)
dummy_input = T.rand(4, 8, 128) dummy_input = T.rand(4, 8, 128)
@ -63,7 +59,6 @@ class EncoderTests(unittest.TestCase):
class L1LossMaskedTests(unittest.TestCase): class L1LossMaskedTests(unittest.TestCase):
def test_in_out(self): def test_in_out(self):
layer = L1LossMasked() layer = L1LossMasked()
dummy_input = T.ones(4, 8, 128).float() dummy_input = T.ones(4, 8, 128).float()
@ -80,7 +75,7 @@ class L1LossMaskedTests(unittest.TestCase):
dummy_input = T.ones(4, 8, 128).float() dummy_input = T.ones(4, 8, 128).float()
dummy_target = T.zeros(4, 8, 128).float() dummy_target = T.zeros(4, 8, 128).float()
dummy_length = (T.arange(5, 9)).long() dummy_length = (T.arange(5, 9)).long()
mask = ((sequence_mask(dummy_length).float() - 1.0) mask = (
* 100.0).unsqueeze(2) (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
output = layer(dummy_input + mask, dummy_target, dummy_length) output = layer(dummy_input + mask, dummy_target, dummy_length)
assert output.item() == 1.0, "1.0 vs {}".format(output.data[0]) assert output.item() == 1.0, "1.0 vs {}".format(output.data[0])

View File

@ -12,34 +12,38 @@ c = load_config(os.path.join(file_path, 'test_config.json'))
class TestLJSpeechDataset(unittest.TestCase): class TestLJSpeechDataset(unittest.TestCase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(TestLJSpeechDataset, self).__init__(*args, **kwargs) super(TestLJSpeechDataset, self).__init__(*args, **kwargs)
self.max_loader_iter = 4 self.max_loader_iter = 4
self.ap = AudioProcessor(sample_rate=c.sample_rate, self.ap = AudioProcessor(
num_mels=c.num_mels, sample_rate=c.sample_rate,
min_level_db=c.min_level_db, num_mels=c.num_mels,
frame_shift_ms=c.frame_shift_ms, min_level_db=c.min_level_db,
frame_length_ms=c.frame_length_ms, frame_shift_ms=c.frame_shift_ms,
ref_level_db=c.ref_level_db, frame_length_ms=c.frame_length_ms,
num_freq=c.num_freq, ref_level_db=c.ref_level_db,
power=c.power, num_freq=c.num_freq,
preemphasis=c.preemphasis, power=c.power,
min_mel_freq=c.min_mel_freq, preemphasis=c.preemphasis,
max_mel_freq=c.max_mel_freq) min_mel_freq=c.min_mel_freq,
max_mel_freq=c.max_mel_freq)
def test_loader(self): def test_loader(self):
dataset = LJSpeech.MyDataset(os.path.join(c.data_path_LJSpeech), dataset = LJSpeech.MyDataset(
os.path.join(c.data_path_LJSpeech, 'metadata.csv'), os.path.join(c.data_path_LJSpeech),
c.r, os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
c.text_cleaner, c.r,
ap = self.ap, c.text_cleaner,
min_seq_len=c.min_seq_len ap=self.ap,
) min_seq_len=c.min_seq_len)
dataloader = DataLoader(dataset, batch_size=2, dataloader = DataLoader(
shuffle=True, collate_fn=dataset.collate_fn, dataset,
drop_last=True, num_workers=c.num_loader_workers) batch_size=2,
shuffle=True,
collate_fn=dataset.collate_fn,
drop_last=True,
num_workers=c.num_loader_workers)
for i, data in enumerate(dataloader): for i, data in enumerate(dataloader):
if i == self.max_loader_iter: if i == self.max_loader_iter:
@ -62,18 +66,22 @@ class TestLJSpeechDataset(unittest.TestCase):
assert mel_input.shape[2] == c.num_mels assert mel_input.shape[2] == c.num_mels
def test_padding(self): def test_padding(self):
dataset = LJSpeech.MyDataset(os.path.join(c.data_path_LJSpeech), dataset = LJSpeech.MyDataset(
os.path.join(c.data_path_LJSpeech, 'metadata.csv'), os.path.join(c.data_path_LJSpeech),
1, os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
c.text_cleaner, 1,
ap = self.ap, c.text_cleaner,
min_seq_len=c.min_seq_len ap=self.ap,
) min_seq_len=c.min_seq_len)
# Test for batch size 1 # Test for batch size 1
dataloader = DataLoader(dataset, batch_size=1, dataloader = DataLoader(
shuffle=False, collate_fn=dataset.collate_fn, dataset,
drop_last=True, num_workers=c.num_loader_workers) batch_size=1,
shuffle=False,
collate_fn=dataset.collate_fn,
drop_last=True,
num_workers=c.num_loader_workers)
for i, data in enumerate(dataloader): for i, data in enumerate(dataloader):
if i == self.max_loader_iter: if i == self.max_loader_iter:
@ -98,9 +106,13 @@ class TestLJSpeechDataset(unittest.TestCase):
assert mel_lengths[0] == mel_input[0].shape[0] assert mel_lengths[0] == mel_input[0].shape[0]
# Test for batch size 2 # Test for batch size 2
dataloader = DataLoader(dataset, batch_size=2, dataloader = DataLoader(
shuffle=False, collate_fn=dataset.collate_fn, dataset,
drop_last=False, num_workers=c.num_loader_workers) batch_size=2,
shuffle=False,
collate_fn=dataset.collate_fn,
drop_last=False,
num_workers=c.num_loader_workers)
for i, data in enumerate(dataloader): for i, data in enumerate(dataloader):
if i == self.max_loader_iter: if i == self.max_loader_iter:
@ -130,9 +142,9 @@ class TestLJSpeechDataset(unittest.TestCase):
assert mel_lengths[idx] == mel_input[idx].shape[0] assert mel_lengths[idx] == mel_input[idx].shape[0]
# check the second itme in the batch # check the second itme in the batch
assert mel_input[1-idx, -1].sum() == 0 assert mel_input[1 - idx, -1].sum() == 0
assert linear_input[1-idx, -1].sum() == 0 assert linear_input[1 - idx, -1].sum() == 0
assert stop_target[1-idx, -1] == 1 assert stop_target[1 - idx, -1] == 1
assert len(mel_lengths.shape) == 1 assert len(mel_lengths.shape) == 1
# check batch conditions # check batch conditions
@@ -141,34 +153,38 @@ class TestLJSpeechDataset(unittest.TestCase):
class TestKusalDataset(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super(TestKusalDataset, self).__init__(*args, **kwargs)
        self.max_loader_iter = 4
        self.ap = AudioProcessor(
            sample_rate=c.sample_rate,
            num_mels=c.num_mels,
            min_level_db=c.min_level_db,
            frame_shift_ms=c.frame_shift_ms,
            frame_length_ms=c.frame_length_ms,
            ref_level_db=c.ref_level_db,
            num_freq=c.num_freq,
            power=c.power,
            preemphasis=c.preemphasis,
            min_mel_freq=c.min_mel_freq,
            max_mel_freq=c.max_mel_freq)

    def test_loader(self):
        dataset = Kusal.MyDataset(
            os.path.join(c.data_path_Kusal),
            os.path.join(c.data_path_Kusal, 'prompts.txt'),
            c.r,
            c.text_cleaner,
            ap=self.ap,
            min_seq_len=c.min_seq_len)
        dataloader = DataLoader(
            dataset,
            batch_size=2,
            shuffle=True,
            collate_fn=dataset.collate_fn,
            drop_last=True,
            num_workers=c.num_loader_workers)
        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
@@ -191,18 +207,22 @@ class TestKusalDataset(unittest.TestCase):
            assert mel_input.shape[2] == c.num_mels

    def test_padding(self):
        dataset = Kusal.MyDataset(
            os.path.join(c.data_path_Kusal),
            os.path.join(c.data_path_Kusal, 'prompts.txt'),
            1,
            c.text_cleaner,
            ap=self.ap,
            min_seq_len=c.min_seq_len)
        # Test for batch size 1
        dataloader = DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=dataset.collate_fn,
            drop_last=True,
            num_workers=c.num_loader_workers)
        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
@@ -227,9 +247,13 @@ class TestKusalDataset(unittest.TestCase):
            assert mel_lengths[0] == mel_input[0].shape[0]
        # Test for batch size 2
        dataloader = DataLoader(
            dataset,
            batch_size=2,
            shuffle=False,
            collate_fn=dataset.collate_fn,
            drop_last=False,
            num_workers=c.num_loader_workers)
        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
@@ -259,9 +283,9 @@ class TestKusalDataset(unittest.TestCase):
            assert mel_lengths[idx] == mel_input[idx].shape[0]
            # check the second item in the batch
            assert mel_input[1 - idx, -1].sum() == 0
            assert linear_input[1 - idx, -1].sum() == 0
            assert stop_target[1 - idx, -1] == 1
            assert len(mel_lengths.shape) == 1
            # check batch conditions

View File

@@ -19,50 +19,51 @@ c = load_config(os.path.join(file_path, 'test_config.json'))
class TacotronTrainTest(unittest.TestCase):
    def test_train_step(self):
        input = torch.randint(0, 24, (8, 128)).long().to(device)
        mel_spec = torch.rand(8, 30, c.num_mels).to(device)
        linear_spec = torch.rand(8, 30, c.num_freq).to(device)
        mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
        stop_targets = torch.zeros(8, 30, 1).float().to(device)
        for idx in mel_lengths:
            stop_targets[:, int(idx.item()):, 0] = 1.0
        stop_targets = stop_targets.view(input.shape[0],
                                         stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
        criterion = L1LossMasked().to(device)
        criterion_st = nn.BCELoss().to(device)
        model = Tacotron(c.embedding_size, c.num_freq, c.num_mels,
                         c.r).to(device)
        model.train()
        model_ref = copy.deepcopy(model)
        count = 0
        for param, param_ref in zip(model.parameters(),
                                    model_ref.parameters()):
            assert (param - param_ref).sum() == 0, param
            count += 1
        optimizer = optim.Adam(model.parameters(), lr=c.lr)
        for i in range(5):
            mel_out, linear_out, align, stop_tokens = model.forward(
                input, mel_spec)
            assert stop_tokens.data.max() <= 1.0
            assert stop_tokens.data.min() >= 0.0
            optimizer.zero_grad()
            loss = criterion(mel_out, mel_spec, mel_lengths)
            stop_loss = criterion_st(stop_tokens, stop_targets)
            loss = loss + criterion(linear_out, linear_spec,
                                    mel_lengths) + stop_loss
            loss.backward()
            optimizer.step()
        # check parameter changes
        count = 0
        for param, param_ref in zip(model.parameters(),
                                    model_ref.parameters()):
            # ignore the pre-highway layer since it works conditionally
            if count not in [148, 59]:
                assert (param != param_ref).any(
                ), "param {} with shape {} not updated!! \n{}\n{}".format(
                    count, param.shape, param, param_ref)
            count += 1
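For reference, a minimal standalone sketch of the stop-target reshaping exercised by the test above: one stop flag is predicted per r decoder frames. The batch size, frame count and r below are made-up example values, not taken from the config.

import torch

r = 5
mel_lengths = torch.tensor([6, 9])
stop_targets = torch.zeros(2, 10, 1)
for b, length in enumerate(mel_lengths):
    # flag every frame past the end of each item as "stop"
    stop_targets[b, length:, 0] = 1.0
# collapse each group of r frames into a single stop flag
stop_targets = stop_targets.view(2, stop_targets.size(1) // r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
print(stop_targets.shape)  # torch.Size([2, 2, 1])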

226
train.py
View File

@@ -1,39 +1,34 @@
import os
import sys
import time
import shutil
import torch
import argparse
import importlib
import traceback
import numpy as np
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from utils.generic_utils import (
    synthesis, remove_experiment_folder, create_experiment_folder,
    save_checkpoint, save_best_model, load_config, lr_decay, count_parameters,
    check_update, get_commit_hash)
from utils.visual import plot_alignment, plot_spectrogram
from models.tacotron import Tacotron
from layers.losses import L1LossMasked
from utils.audio import AudioProcessor
torch.manual_seed(1)
torch.set_num_threads(4)
use_cuda = torch.cuda.is_available()
def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
          ap, epoch):
    model = model.train()
    epoch_time = 0
    avg_linear_loss = 0
@@ -54,7 +49,8 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
        stop_targets = data[5]
        # set stop targets view, we predict a single stop token per r frames prediction
        stop_targets = stop_targets.view(text_input.shape[0],
                                         stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
        current_step = num_iter + args.restore_step + \
@@ -89,7 +85,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
        # loss computation
        stop_loss = criterion_st(stop_tokens, stop_targets)
        mel_loss = criterion(mel_output, mel_input, mel_lengths)
        linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths)\
            + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                              linear_input[:, :, :n_priority_freq],
                              mel_lengths)
@@ -106,7 +102,8 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
        # backpass and check the grad norm for stop loss
        stop_loss.backward()
        grad_norm_st, skip_flag = check_update(model.module.decoder.stopnet,
                                               0.5, 100)
        if skip_flag:
            optimizer_st.zero_grad()
            print(" | | > Iteration skipped for stopnet!!")
@@ -117,9 +114,10 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
        epoch_time += step_time
        if current_step % c.print_step == 0:
            print(" | | > Step:{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
                  "MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
                  "GradNormST:{:.5f} StepTime:{:.2f}".format(num_iter,
                                                             current_step,
                                                             loss.item(),
                                                             linear_loss.item(),
                                                             mel_loss.item(),
@@ -147,8 +145,9 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
        if current_step % c.save_step == 0:
            if c.checkpoint:
                # save model
                save_checkpoint(model, optimizer, optimizer_st,
                                linear_loss.item(), OUT_PATH, current_step,
                                epoch)
            # Diagnostic visualizations
            const_spec = linear_output[0].data.cpu().numpy()
@@ -168,8 +167,11 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
            ap.griffin_lim_iters = 60
            audio_signal = ap.inv_spectrogram(audio_signal.T)
            try:
                tb.add_audio(
                    'SampleAudio',
                    audio_signal,
                    current_step,
                    sample_rate=c.sample_rate)
            except:
                pass
@@ -180,9 +182,9 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
    avg_step_time /= (num_iter + 1)
    # print epoch stats
    print(" | | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
          "AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "
          "AvgStopLoss:{:.5f} EpochTime:{:.2f} "
          "AvgStepTime:{:.2f}".format(current_step,
                                      avg_total_loss,
                                      avg_linear_loss,
@@ -209,10 +211,12 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
    avg_mel_loss = 0
    avg_stop_loss = 0
    print(" | > Validation")
    test_sentences = [
        "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
        "Be a voice, not an echo.",
        "I'm sorry Dave. I'm afraid I can't do that.",
        "This cake is great. It's so delicious and moist."
    ]
    n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
    with torch.no_grad():
        if data_loader is not None:
@@ -228,7 +232,9 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
                stop_targets = data[5]
                # set stop targets view, we predict a single stop token per r frames prediction
                stop_targets = stop_targets.view(text_input.shape[0],
                                                 stop_targets.size(1) // c.r,
                                                 -1)
                stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
                # dispatch data to GPU
@@ -256,11 +262,11 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
                epoch_time += step_time
                if num_iter % c.print_step == 0:
                    print(
                        " | | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} "
                        "StopLoss: {:.5f} ".format(loss.item(), linear_loss.item(),
                                                   mel_loss.item(), stop_loss.item()),
                        flush=True)
                avg_linear_loss += linear_loss.item()
                avg_mel_loss += mel_loss.item()
@@ -278,15 +284,19 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
            tb.add_image('ValVisual/Reconstruction', const_spec, current_step)
            tb.add_image('ValVisual/GroundTruth', gt_spec, current_step)
            tb.add_image('ValVisual/ValidationAlignment', align_img,
                         current_step)
            # Sample audio
            audio_signal = linear_output[idx].data.cpu().numpy()
            ap.griffin_lim_iters = 60
            audio_signal = ap.inv_spectrogram(audio_signal.T)
            try:
                tb.add_audio(
                    'ValSampleAudio',
                    audio_signal,
                    current_step,
                    sample_rate=c.sample_rate)
            except:
                # sometimes audio signal is out of boundaries
                pass
@@ -298,81 +308,88 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
        avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss
        # Plot Learning Stats
        tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss,
                      current_step)
        tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss,
                      current_step)
        tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step)
        tb.add_scalar('ValEpochLoss/Stop_loss', avg_stop_loss,
                      current_step)
        # test sentences
        ap.griffin_lim_iters = 60
        for idx, test_sentence in enumerate(test_sentences):
            wav, linear_spec, alignments = synthesis(model, ap, test_sentence,
                                                     use_cuda, c.text_cleaner)
            try:
                wav_name = 'TestSentences/{}'.format(idx)
                tb.add_audio(
                    wav_name, wav, current_step, sample_rate=c.sample_rate)
            except:
                pass
            align_img = alignments[0].data.cpu().numpy()
            linear_spec = plot_spectrogram(linear_spec, ap)
            align_img = plot_alignment(align_img)
            tb.add_image('TestSentences/{}_Spectrogram'.format(idx), linear_spec,
                         current_step)
            tb.add_image('TestSentences/{}_Alignment'.format(idx), align_img,
                         current_step)
    return avg_linear_loss
def main(args):
    dataset = importlib.import_module('datasets.' + c.dataset)
    Dataset = getattr(dataset, 'MyDataset')
    audio = importlib.import_module('utils.' + c.audio_processor)
    AudioProcessor = getattr(audio, 'AudioProcessor')
    ap = AudioProcessor(
        sample_rate=c.sample_rate,
        num_mels=c.num_mels,
        min_level_db=c.min_level_db,
        frame_shift_ms=c.frame_shift_ms,
        frame_length_ms=c.frame_length_ms,
        ref_level_db=c.ref_level_db,
        num_freq=c.num_freq,
        power=c.power,
        preemphasis=c.preemphasis,
        min_mel_freq=c.min_mel_freq,
        max_mel_freq=c.max_mel_freq)
    # Setup the dataset
    train_dataset = Dataset(
        c.data_path,
        c.meta_file_train,
        c.r,
        c.text_cleaner,
        ap=ap,
        min_seq_len=c.min_seq_len)
    train_loader = DataLoader(
        train_dataset,
        batch_size=c.batch_size,
        shuffle=False,
        collate_fn=train_dataset.collate_fn,
        drop_last=False,
        num_workers=c.num_loader_workers,
        pin_memory=True)
    if c.run_eval:
        val_dataset = Dataset(
            c.data_path, c.meta_file_val, c.r, c.text_cleaner, ap=ap)
        val_loader = DataLoader(
            val_dataset,
            batch_size=c.eval_batch_size,
            shuffle=False,
            collate_fn=val_dataset.collate_fn,
            drop_last=False,
            num_workers=4,
            pin_memory=True)
    else:
        val_loader = None
    model = Tacotron(c.embedding_size, ap.num_freq, c.num_mels, c.r)
    print(" | > Num output units : {}".format(ap.num_freq), flush=True)
    optimizer = optim.Adam(model.parameters(), lr=c.lr)
@@ -394,7 +411,8 @@ def main(args):
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.cuda()
        print(
            " > Model restored from step %d" % checkpoint['step'], flush=True)
        start_epoch = checkpoint['step'] // len(train_loader)
        best_loss = checkpoint['linear_loss']
        args.restore_step = checkpoint['step']
@@ -416,22 +434,36 @@ def main(args):
        best_loss = float('inf')
    for epoch in range(0, c.epochs):
        train_loss, current_step = train(model, criterion, criterion_st,
                                         train_loader, optimizer, optimizer_st,
                                         ap, epoch)
        val_loss = evaluate(model, criterion, criterion_st, val_loader, ap,
                            current_step)
        print(
            " | > Train Loss: {:.5f} Validation Loss: {:.5f}".format(
                train_loss, val_loss),
            flush=True)
        best_loss = save_best_model(model, optimizer, train_loss, best_loss,
                                    OUT_PATH, current_step, epoch)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--restore_path',
        type=str,
        help='Folder path to checkpoints',
        default=0)
    parser.add_argument(
        '--config_path',
        type=str,
        help='path to config file for training',
    )
    parser.add_argument(
        '--debug',
        type=bool,
        default=False,
        help='do not ask for git hash before run.')
    args = parser.parse_args()
    # setup output paths and read configs

View File

@@ -9,10 +9,19 @@ _mel_basis = None
class AudioProcessor(object):
    def __init__(self,
                 sample_rate,
                 num_mels,
                 min_level_db,
                 frame_shift_ms,
                 frame_length_ms,
                 ref_level_db,
                 num_freq,
                 power,
                 preemphasis,
                 min_mel_freq,
                 max_mel_freq,
                 griffin_lim_iters=None):
        self.sample_rate = sample_rate
        self.num_mels = num_mels
@@ -30,7 +39,8 @@ class AudioProcessor(object):
    def save_wav(self, wav, path):
        wav *= 32767 / max(0.01, np.max(np.abs(wav)))
        librosa.output.write_wav(
            path, wav.astype(np.float), self.sample_rate, norm=True)

    def _linear_to_mel(self, spectrogram):
        global _mel_basis
@@ -40,8 +50,9 @@ class AudioProcessor(object):
    def _build_mel_basis(self, ):
        n_fft = (self.num_freq - 1) * 2
        return librosa.filters.mel(
            self.sample_rate, n_fft, n_mels=self.num_mels)
        # fmin=self.min_mel_freq, fmax=self.max_mel_freq)

    def _normalize(self, S):
        return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
@@ -86,9 +97,9 @@ class AudioProcessor(object):
        S = self._db_to_amp(S + self.ref_level_db)  # Convert back to linear
        # Reconstruct phase
        if self.preemphasis != 0:
            return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
        else:
            return self._griffin_lim(S**self.power)

    def _griffin_lim(self, S):
        '''Applies the Griffin-Lim algorithm.
@@ -113,7 +124,8 @@ class AudioProcessor(object):
    def _stft(self, y):
        n_fft, hop_length, win_length = self._stft_parameters()
        return librosa.stft(
            y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)

    def _istft(self, y):
        _, hop_length, win_length = self._stft_parameters()
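For reference, a small sketch of the amplitude-to-dB conversion and normalization pair this class relies on. The min_level_db and ref_level_db values are typical assumptions, and the body of amp_to_db mirrors the usual implementation rather than being quoted from this hunk.

import numpy as np

min_level_db = -100   # assumed config value
ref_level_db = 20     # assumed config value

def amp_to_db(x):
    min_level = np.exp(min_level_db / 20 * np.log(10))  # amplitude floor (1e-5)
    return 20 * np.log10(np.maximum(min_level, x))

def normalize(S):
    return np.clip((S - min_level_db) / -min_level_db, 0, 1)

S = amp_to_db(np.random.rand(5, 5)) - ref_level_db
print(normalize(S).min() >= 0.0, normalize(S).max() <= 1.0)  # True True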

View File

@@ -8,11 +8,23 @@ import lws
_mel_basis = None


class AudioProcessor(object):
    def __init__(
            self,
            sample_rate,
            num_mels,
            min_level_db,
            frame_shift_ms,
            frame_length_ms,
            ref_level_db,
            num_freq,
            power,
            preemphasis,
            min_mel_freq,
            max_mel_freq,
            griffin_lim_iters=None,
    ):
        print(" > Setting up Audio Processor...")
        self.sample_rate = sample_rate
        self.num_mels = num_mels
@@ -25,14 +37,15 @@ class AudioProcessor(object):
        self.min_mel_freq = min_mel_freq
        self.max_mel_freq = max_mel_freq
        self.griffin_lim_iters = griffin_lim_iters
        self.preemphasis = preemphasis
        self.n_fft, self.hop_length, self.win_length = self._stft_parameters()
        if preemphasis == 0:
            print(" | > Preemphasis is disabled.")
    def save_wav(self, wav, path):
        wav *= 32767 / max(0.01, np.max(np.abs(wav)))
        librosa.output.write_wav(
            path, wav.astype(np.float), self.sample_rate, norm=True)

    def _stft_parameters(self, ):
        n_fft = int((self.num_freq - 1) * 2)
@@ -44,14 +57,21 @@ class AudioProcessor(object):
        if n_fft % win_length != 0:
            win_length = n_fft / 2
            print(" | > win_length is set to default ({}).".format(win_length))
        print(" | > fft size: {}, hop length: {}, win length: {}".format(
            n_fft, hop_length, win_length))
        return int(n_fft), int(hop_length), int(win_length)

    def _lws_processor(self):
        try:
            return lws.lws(
                self.win_length,
                self.hop_length,
                fftsize=self.n_fft,
                mode="speech")
        except:
            raise RuntimeError(
                " !! WindowLength({}) is not multiple of HopLength({}).".
                format(self.win_length, self.hop_length))

    def _amp_to_db(self, x):
        min_level = np.exp(self.min_level_db / 20 * np.log(10))
@@ -96,7 +116,7 @@ class AudioProcessor(object):
        S = self._denormalize(spectrogram)
        S = self._db_to_amp(S + self.ref_level_db)  # Convert back to linear
        processor = self._lws_processor()
        D = processor.run_lws(S.astype(np.float64).T**self.power)
        y = processor.istft(D).astype(np.float32)
        # Reconstruct phase
        if self.preemphasis:
@@ -111,7 +131,10 @@ class AudioProcessor(object):
        return np.dot(_mel_basis, spectrogram)

    def _build_mel_basis(self, ):
        return librosa.filters.mel(
            self.sample_rate, self.n_fft, n_mels=self.num_mels)
        # fmin=self.min_mel_freq, fmax=self.max_mel_freq)

    def melspectrogram(self, y):
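A quick shape check for the _linear_to_mel / _build_mel_basis path shown above, as a sketch with assumed example values (the real sample rate, FFT size and mel count come from the config, and the positional librosa.filters.mel call matches the older librosa API this code uses):

import numpy as np
import librosa

sample_rate, n_fft, num_mels = 22050, 2048, 80   # assumed example values
mel_basis = librosa.filters.mel(sample_rate, n_fft, n_mels=num_mels)
spectrogram = np.random.rand(n_fft // 2 + 1, 10)  # (1 + n_fft/2 bins, 10 frames)
mel_spec = np.dot(mel_basis, spectrogram)
print(mel_basis.shape, mel_spec.shape)  # (80, 1025) (80, 10)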

View File

@@ -4,9 +4,8 @@ import numpy as np
def _pad_data(x, length):
    _pad = 0
    assert x.ndim == 1
    return np.pad(
        x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)


def prepare_data(inputs):
@@ -17,8 +16,10 @@ def prepare_data(inputs):
def _pad_tensor(x, length):
    _pad = 0
    assert x.ndim == 2
    x = np.pad(
        x, [[0, 0], [0, length - x.shape[1]]],
        mode='constant',
        constant_values=_pad)
    return x
@@ -32,7 +33,8 @@ def prepare_tensor(inputs, out_steps):
def _pad_stop_target(x, length):
    _pad = 1.
    assert x.ndim == 1
    return np.pad(
        x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)


def prepare_stop_target(inputs, out_steps):
@@ -44,6 +46,7 @@ def prepare_stop_target(inputs, out_steps):
def pad_per_step(inputs, pad_len):
    timesteps = inputs.shape[-1]
    return np.pad(
        inputs, [[0, 0], [0, 0], [0, pad_len]],
        mode='constant',
        constant_values=0.0)
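A short usage sketch of the padding convention above: stop targets are padded with 1.0 (the "stop" value) out to a length that is a multiple of the output step. The round-up rule below mirrors what prepare_stop_target is expected to do, and the lengths are invented for the example.

import numpy as np

def _pad_stop_target(x, length):
    return np.pad(
        x, (0, length - x.shape[0]), mode='constant', constant_values=1.)

targets = [np.zeros(5), np.zeros(8)]   # two items of different length
out_steps = 7                          # r, the number of frames per step
max_len = max(t.shape[0] for t in targets)
remainder = max_len % out_steps
pad_len = max_len + (out_steps - remainder) if remainder > 0 else max_len
batch = np.stack([_pad_stop_target(t, pad_len) for t in targets])
print(batch.shape, batch[0, -1])  # (2, 14) 1.0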

View File

@@ -28,10 +28,13 @@ def load_config(config_path):
def get_commit_hash():
    """https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script"""
    try:
        subprocess.check_output(['git', 'diff-index', '--quiet',
                                 'HEAD'])  # Verify client is clean
    except:
        raise RuntimeError(
            " !! Commit before training to get the commit hash.")
    commit = subprocess.check_output(['git', 'rev-parse', '--short',
                                      'HEAD']).decode().strip()
    print(' > Git Hash: {}'.format(commit))
    return commit
@@ -43,7 +46,8 @@ def create_experiment_folder(root_path, model_name, debug):
        commit_hash = 'debug'
    else:
        commit_hash = get_commit_hash()
    output_folder = os.path.join(
        root_path, date_str + '-' + model_name + '-' + commit_hash)
    os.makedirs(output_folder, exist_ok=True)
    print(" > Experiment folder: {}".format(output_folder))
    return output_folder
@@ -52,7 +56,7 @@ def create_experiment_folder(root_path, model_name, debug):
def remove_experiment_folder(experiment_path):
    """Keep the experiment folder if it contains a checkpoint; otherwise remove it."""
    checkpoint_files = glob.glob(experiment_path + "/*.pth.tar")
    if len(checkpoint_files) < 1:
        if os.path.exists(experiment_path):
            shutil.rmtree(experiment_path)
@@ -86,13 +90,15 @@ def save_checkpoint(model, optimizer, optimizer_st, model_loss, out_path,
    print(" | | > Checkpoint saving : {}".format(checkpoint_path))
    new_state_dict = _trim_model_state_dict(model.state_dict())
    state = {
        'model': new_state_dict,
        'optimizer': optimizer.state_dict(),
        'optimizer_st': optimizer_st.state_dict(),
        'step': current_step,
        'epoch': epoch,
        'linear_loss': model_loss,
        'date': datetime.date.today().strftime("%B %d, %Y")
    }
    torch.save(state, checkpoint_path)
@@ -100,12 +106,14 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path,
                    current_step, epoch):
    if model_loss < best_loss:
        new_state_dict = _trim_model_state_dict(model.state_dict())
        state = {
            'model': new_state_dict,
            'optimizer': optimizer.state_dict(),
            'step': current_step,
            'epoch': epoch,
            'linear_loss': model_loss,
            'date': datetime.date.today().strftime("%B %d, %Y")
        }
        best_loss = model_loss
        bestmodel_path = 'best_model.pth.tar'
        bestmodel_path = os.path.join(out_path, bestmodel_path)
@@ -161,12 +169,12 @@ def sequence_mask(sequence_length, max_len=None):
def synthesis(model, ap, text, use_cuda, text_cleaner):
    text_cleaner = [text_cleaner]
    seq = np.array(text_to_sequence(text, text_cleaner))
    chars_var = torch.from_numpy(seq).unsqueeze(0)
    if use_cuda:
        chars_var = chars_var.cuda().long()
    _, linear_out, alignments, _ = model.forward(chars_var)
    linear_out = linear_out[0].data.cpu().numpy()
    wav = ap.inv_spectrogram(linear_out.T)
    return wav, linear_out, alignments

View File

@@ -4,7 +4,6 @@ import re
from utils.text import cleaners
from utils.text.symbols import symbols
# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}

View File

@@ -14,31 +14,31 @@ import re
from unidecode import unidecode
from .numbers import normalize_numbers
# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
                  for x in [
                      ('mrs', 'misess'),
                      ('mr', 'mister'),
                      ('dr', 'doctor'),
                      ('st', 'saint'),
                      ('co', 'company'),
                      ('jr', 'junior'),
                      ('maj', 'major'),
                      ('gen', 'general'),
                      ('drs', 'doctors'),
                      ('rev', 'reverend'),
                      ('lt', 'lieutenant'),
                      ('hon', 'honorable'),
                      ('sgt', 'sergeant'),
                      ('capt', 'captain'),
                      ('esq', 'esquire'),
                      ('ltd', 'limited'),
                      ('col', 'colonel'),
                      ('ft', 'fort'),
                  ]]


def expand_abbreviations(text):
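A short usage sketch of the abbreviation table above. The body of expand_abbreviations is not part of this hunk, so the loop below is an assumed, conventional implementation.

import re

_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
                  for x in [('mr', 'mister'), ('dr', 'doctor')]]

def expand_abbreviations(text):
    # assumed implementation: apply every (regex, replacement) pair in order
    for regex, replacement in _abbreviations:
        text = re.sub(regex, replacement, text)
    return text

print(expand_abbreviations("Dr. Smith met Mr. Jones."))
# doctor Smith met mister Jones.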

View File

@@ -1,17 +1,16 @@
# -*- coding: utf-8 -*-
import re
valid_symbols = [
    'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1',
    'AH2', 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0',
    'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0',
    'ER1', 'ER2', 'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0',
    'IH1', 'IH2', 'IY', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG',
    'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH',
    'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W',
    'Y', 'Z', 'ZH'
]
_valid_symbol_set = set(valid_symbols)
@@ -27,8 +26,10 @@ class CMUDict:
        else:
            entries = _parse_cmudict(file_or_path)
        if not keep_ambiguous:
            entries = {
                word: pron
                for word, pron in entries.items() if len(pron) == 1
            }
        self._entries = entries

    def __len__(self):

View File

@@ -8,61 +8,45 @@ _ordinal_re = re.compile(r'([0-9]+)(st|nd|rd|th)')
_number_re = re.compile(r'[0-9]+')
_units = [
    '', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine',
    'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen',
    'seventeen', 'eighteen', 'nineteen'
]
_tens = [
    '',
    'ten',
    'twenty',
    'thirty',
    'forty',
    'fifty',
    'sixty',
    'seventy',
    'eighty',
    'ninety',
]
_digit_groups = [
    '',
    'thousand',
    'million',
    'billion',
    'trillion',
    'quadrillion',
]
_ordinal_suffixes = [
    ('one', 'first'),
    ('two', 'second'),
    ('three', 'third'),
    ('five', 'fifth'),
    ('eight', 'eighth'),
    ('nine', 'ninth'),
    ('twelve', 'twelfth'),
    ('ty', 'tieth'),
]


def _remove_commas(m):
    return m.group(1).replace(',', '')
@@ -114,7 +98,7 @@ def _standard_number_to_words(n, digit_group):
def _number_to_words(n):
    # Handle special cases first, then go to the standard case:
    if n >= 1000000000000000000:
        return str(n)  # Too large, just return the digits
    elif n == 0:
        return 'zero'
    elif n % 100 == 0 and n % 1000 != 0 and n < 3000:

View File

@@ -1,6 +1,4 @@
# -*- coding: utf-8 -*-
'''
Defines the set of symbols used in text input to the model.
@@ -19,6 +17,5 @@ _arpabet = ['@' + s for s in cmudict.valid_symbols]
# Export all symbols:
symbols = [_pad, _eos] + list(_characters) + _arpabet

if __name__ == '__main__':
    print(symbols)

View File

@@ -6,8 +6,8 @@ import matplotlib.pyplot as plt
def plot_alignment(alignment, info=None):
    fig, ax = plt.subplots(figsize=(16, 10))
    im = ax.imshow(
        alignment.T, aspect='auto', origin='lower', interpolation='none')
    fig.colorbar(im, ax=ax)
    xlabel = 'Decoder timestep'
    if info is not None:
@@ -17,7 +17,7 @@ def plot_alignment(alignment, info=None):
    plt.tight_layout()
    fig.canvas.draw()
    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
    plt.close()
    return data
@@ -30,6 +30,6 @@ def plot_spectrogram(linear_output, audio):
    plt.tight_layout()
    fig.canvas.draw()
    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
    plt.close()
    return data
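Both plotting helpers rely on the same figure-to-array conversion; here is a standalone sketch of that trick, with the non-interactive Agg backend assumed so it runs headless and example array sizes that are not taken from the repo.

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(figsize=(4, 3))
ax.imshow(np.random.rand(20, 30), aspect='auto', origin='lower',
          interpolation='none')
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
plt.close()
print(data.shape)  # (height_px, width_px, 3)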