From b5f2181e04ebe293f39c49a4304861c52218b7c4 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Fri, 23 Feb 2018 08:35:53 -0800 Subject: [PATCH 01/13] teacher forcing with combining --- layers/tacotron.py | 22 ++++++++++++++++------ models/tacotron.py | 7 ++++--- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/layers/tacotron.py b/layers/tacotron.py index 9b31d02b..b977c51e 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -153,7 +153,7 @@ class CBHG(nn.Module): out = conv1d(x) out = out[:, :, :T] outs.append(out) - + x = torch.cat(outs, dim=1) assert x.size(1) == self.in_features * len(self.conv1d_banks) @@ -236,7 +236,7 @@ class Decoder(nn.Module): # RNN_state -> |Linear| -> mel_spec self.proj_to_mel = nn.Linear(256, memory_dim * r) - def forward(self, inputs, memory=None, memory_lengths=None): + def forward(self, inputs, memory=None, input_lengths=None): r""" Decoder forward step. @@ -247,7 +247,7 @@ class Decoder(nn.Module): inputs: Encoder outputs. memory: Decoder memory (autoregression. If None (at eval-time), decoder outputs are used as decoder inputs. - memory_lengths: Encoder output (memory) lengths. If not None, used for + input_lengths: Encoder output (memory) lengths. If not None, used for attention masking. Shapes: @@ -258,8 +258,8 @@ class Decoder(nn.Module): # TODO: take this segment into Attention module. processed_inputs = self.input_layer(inputs) - if memory_lengths is not None: - mask = get_mask_from_lengths(processed_inputs, memory_lengths) + if input_lengths is not None: + mask = get_mask_from_lengths(processed_inputs, input_lengths) else: mask = None @@ -300,7 +300,17 @@ class Decoder(nn.Module): memory_input = initial_memory while True: if t > 0: - memory_input = outputs[-1] if greedy else memory[t - 1] + # using harmonized teacher-forcing. + # from https://arxiv.org/abs/1707.06588 + if greedy: + memory_input = outputs[-1] + else: + # combine prev. model output and prev. 
real target + memory_input = torch.div(outputs[-1] + memory[t-1], 2.0) + # add a random noise + memory_input += torch.autograd.Variable( + torch.randn(memory_input.size())).type_as(memory_input) + # Prenet processed_memory = self.prenet(memory_input) diff --git a/models/tacotron.py b/models/tacotron.py index 6e67edf8..c6218e40 100644 --- a/models/tacotron.py +++ b/models/tacotron.py @@ -34,12 +34,13 @@ class Tacotron(nn.Module): encoder_outputs = self.encoder(inputs) if self.use_memory_mask: - memory_lengths = input_lengths + input_lengths = input_lengths else: - memory_lengths = None + input_lengths = None + # (B, T', mel_dim*r) mel_outputs, alignments = self.decoder( - encoder_outputs, mel_specs, memory_lengths=memory_lengths) + encoder_outputs, mel_specs, input_lengths=input_lengths) # Post net processing below From 56f8b2d19f90807ce7e5801b780a8d4a6da96541 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Mon, 26 Feb 2018 05:33:54 -0800 Subject: [PATCH 02/13] Harmonized teacher-forcing --- config.json | 2 +- layers/tacotron.py | 10 +++++++--- train.py | 23 +++++++++++++++-------- utils/generic_utils.py | 1 + 4 files changed, 24 insertions(+), 12 deletions(-) diff --git a/config.json b/config.json index 01860746..ebb0187d 100644 --- a/config.json +++ b/config.json @@ -12,7 +12,7 @@ "text_cleaner": "english_cleaners", "epochs": 2000, - "lr": 0.003, + "lr": 0.001, "batch_size": 180, "r": 5, diff --git a/layers/tacotron.py b/layers/tacotron.py index b977c51e..c43d3dd3 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -307,9 +307,13 @@ class Decoder(nn.Module): else: # combine prev. model output and prev. real target memory_input = torch.div(outputs[-1] + memory[t-1], 2.0) + memory_input = torch.nn.functional.dropout(memory_input, + 0.1, + training=True) # add a random noise - memory_input += torch.autograd.Variable( - torch.randn(memory_input.size())).type_as(memory_input) + noise = torch.autograd.Variable( + memory_input.data.new(ins.size()).normal_(0.0, 1.0)) + memory_input = memory_input + noise # Prenet processed_memory = self.prenet(memory_input) @@ -360,5 +364,5 @@ class Decoder(nn.Module): return outputs, alignments -def is_end_of_frames(output, eps=0.1): #0.2 +def is_end_of_frames(output, eps=0.2): #0.2 return (output.data <= eps).all() diff --git a/train.py b/train.py index 0d432cce..99b47a9b 100644 --- a/train.py +++ b/train.py @@ -90,9 +90,6 @@ def main(args): # onnx.export(model, dummy_input, model_proto_path, verbose=True) # tb.add_graph_onnx(model_proto_path) - if use_cuda: - model = nn.DataParallel(model.cuda()) - optimizer = optim.Adam(model.parameters(), lr=c.lr) if args.restore_step: @@ -103,10 +100,20 @@ def main(args): print("\n > Model restored from step %d\n" % args.restore_step) start_epoch = checkpoint['step'] // len(dataloader) best_loss = checkpoint['linear_loss'] - else: + elif args.restore_path: + checkpoint = torch.load(args.restore_path) + model.load_state_dict(checkpoint['model']) + optimizer.load_state_dict(checkpoint['optimizer']) + print("\n > Model restored from step %d\n" % checkpoint['step']) + start_epoch = checkpoint['step'] // len(dataloader) + best_loss = checkpoint['linear_loss'] start_epoch = 0 + else: print("\n > Starting a new training") + if use_cuda: + model = nn.DataParallel(model.cuda()) + num_params = count_parameters(model) print(" | > Model has {} parameters".format(num_params)) @@ -142,9 +149,9 @@ def main(args): current_step = num_iter + args.restore_step + epoch * len(dataloader) + 1 # setup lr - current_lr = lr_decay(c.lr, 
current_step) - for params_group in optimizer.param_groups: - params_group['lr'] = current_lr + # current_lr = lr_decay(c.lr, current_step) + # for params_group in optimizer.param_groups: + # params_group['lr'] = current_lr optimizer.zero_grad() @@ -192,7 +199,7 @@ def main(args): # loss = loss.cuda() loss.backward() - grad_norm = nn.utils.clip_grad_norm(model.parameters(), 1.) ## TODO: maybe no need + grad_norm = nn.utils.clip_grad_norm(model.parameters(), 0.5) ## TODO: maybe no need optimizer.step() step_time = time.time() - start_time diff --git a/utils/generic_utils.py b/utils/generic_utils.py index ca32060c..0877056b 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -7,6 +7,7 @@ import datetime import json import torch import numpy as np +from collections import OrderedDict class AttrDict(dict): From 1d684ea0e88ff7d0ae9099f1f9d995dfc36d18e9 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Tue, 27 Feb 2018 06:25:28 -0800 Subject: [PATCH 03/13] ReadMe update --- README.md | 27 +++++++++++++++++++++------ config.json | 2 +- layers/tacotron.py | 2 +- train.py | 3 ++- utils/generic_utils.py | 12 ++++++------ 5 files changed, 31 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 0f68c30a..8e35d798 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,9 @@ # TTS (Work in Progress...) +TTS targets a Text2Speech engine lightweight in computation with hight quality speech construction. -Here we have pytorch implementation of: -- Tacotron: [A Fully End-to-End Text-To-Speech Synthesis Model](https://arxiv.org/abs/1703.10135). -- Tacotron2 (TODO): [Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions](https://arxiv.org/pdf/1712.05884.pdf) +Here we have pytorch implementation of Tacotron: [A Fully End-to-End Text-To-Speech Synthesis Model](https://arxiv.org/abs/1703.10135) as the start point. We plan to improve the model by the recent updated at the field. -At the end, it should be easy to add new models and try different architectures. - -You can find [here](https://www.evernote.com/shard/s146/sh/9544e7e9-d372-4610-a7b7-3ddcb63d5dac/d01d33837dab625229dec3cfb4cfb887) a brief note about possible TTS architectures and their comparisons. +You can find [here](https://www.evernote.com/shard/s146/sh/9544e7e9-d372-4610-a7b7-3ddcb63d5dac/d01d33837dab625229dec3cfb4cfb887) a brief note pointing possible TTS architectures and their comparisons. ## Requirements Highly recommended to use [miniconda](https://conda.io/miniconda.html) for easier installation. @@ -72,3 +69,21 @@ Best way to test your pretrained network is to use the Notebook under ```noteboo ## Contribution Any kind of contribution is highly welcome as we are propelled by the open-source spirit. If you like to add or edit things in code, please also consider to write tests to verify your segment so that we can be sure things are on the track as this repo gets bigger. + +## TODO +- Make the default Tacotron architecture functional with reasonable fidelity. [DONE] +- Update the architecture with the latest improvements at the field. (e.g. Monotonic Attention, World Vocoder) + - Using harmonized teacher forcing proposed by + - Update the attention module with a monotonic alternative. 
(e.g GMM attention, Window based attention) + - References: + - [Efficient Neural Audio Synthesis](https://arxiv.org/pdf/1506.07503.pdf) + - [Attention-Based models for speech recognition](https://arxiv.org/pdf/1308.0850.pdf) + - [Char2Wav](https://openreview.net/pdf?id=B1VWyySKx) + - [VoiceLoop](https://arxiv.org/pdf/1707.06588.pdf) +- Simplify the architecture and push the limits of performance vs efficiency. +- Improve vocoder part of the network. + - Possible Solutions: + - WORLD vocoder + - [WaveRNN](https://128.84.21.199/pdf/1802.08435.pdf) + - [Faster WaveNet](https://arxiv.org/abs/1611.09482) + - [Parallel WaveNet](https://arxiv.org/abs/1711.10433) diff --git a/config.json b/config.json index ebb0187d..1a617ea8 100644 --- a/config.json +++ b/config.json @@ -12,7 +12,7 @@ "text_cleaner": "english_cleaners", "epochs": 2000, - "lr": 0.001, + "lr": 0.005, "batch_size": 180, "r": 5, diff --git a/layers/tacotron.py b/layers/tacotron.py index c43d3dd3..ac348017 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -312,7 +312,7 @@ class Decoder(nn.Module): training=True) # add a random noise noise = torch.autograd.Variable( - memory_input.data.new(ins.size()).normal_(0.0, 1.0)) + memory_input.data.new(memory_input.size()).normal_(0.0, 1.0)) memory_input = memory_input + noise # Prenet diff --git a/train.py b/train.py index 99b47a9b..230f0bef 100644 --- a/train.py +++ b/train.py @@ -132,7 +132,8 @@ def main(args): #lr_scheduler = ReduceLROnPlateau(optimizer, factor=c.lr_decay, # patience=c.lr_patience, verbose=True) epoch_time = 0 - best_loss = float('inf') + if 'best_loss' not in locals(): + best_loss = float('inf') for epoch in range(0, c.epochs): print("\n | > Epoch {}/{}".format(epoch, c.epochs)) diff --git a/utils/generic_utils.py b/utils/generic_utils.py index 0877056b..ed9661ff 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -198,13 +198,13 @@ class Progbar(object): eta_format = '%ds' % eta info = ' - ETA: %s' % eta_format + + if time_per_unit >= 1: + info += ' %.0fs/step' % time_per_unit + elif time_per_unit >= 1e-3: + info += ' %.0fms/step' % (time_per_unit * 1e3) else: - if time_per_unit >= 1: - info += ' %.0fs/step' % time_per_unit - elif time_per_unit >= 1e-3: - info += ' %.0fms/step' % (time_per_unit * 1e3) - else: - info += ' %.0fus/step' % (time_per_unit * 1e6) + info += ' %.0fus/step' % (time_per_unit * 1e6) for k in self.unique_values: info += ' - %s:' % k From b21fa4dd44cf41e3271ce67b3d4d4697ecbf1bd7 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Tue, 27 Feb 2018 06:27:33 -0800 Subject: [PATCH 04/13] Readme update --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 8e35d798..9931b40c 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,7 @@ Best way to test your pretrained network is to use the Notebook under ```noteboo Any kind of contribution is highly welcome as we are propelled by the open-source spirit. If you like to add or edit things in code, please also consider to write tests to verify your segment so that we can be sure things are on the track as this repo gets bigger. ## TODO -- Make the default Tacotron architecture functional with reasonable fidelity. [DONE] +- [DONE] Make the default Tacotron architecture functional with reasonable fidelity. - Update the architecture with the latest improvements at the field. (e.g. Monotonic Attention, World Vocoder) - Using harmonized teacher forcing proposed by - Update the attention module with a monotonic alternative. 
(e.g GMM attention, Window based attention) From d6b2af7ca9bb347ac50ca4846a2b303c9a5658f5 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Tue, 27 Feb 2018 07:31:07 -0800 Subject: [PATCH 05/13] check gradients for big errorenous changes --- config.json | 1 + train.py | 16 +++++++++------- utils/generic_utils.py | 17 +++++++++++++++-- 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/config.json b/config.json index 1a617ea8..3505e114 100644 --- a/config.json +++ b/config.json @@ -13,6 +13,7 @@ "epochs": 2000, "lr": 0.005, + "warmup_steps": 4000, "batch_size": 180, "r": 5, diff --git a/train.py b/train.py index 230f0bef..9b8e8c9f 100644 --- a/train.py +++ b/train.py @@ -21,7 +21,7 @@ from tensorboardX import SummaryWriter from utils.generic_utils import (Progbar, remove_experiment_folder, create_experiment_folder, save_checkpoint, save_best_model, load_config, lr_decay, - count_parameters) + count_parameters, check_update) from utils.model import get_param_size from utils.visual import plot_alignment, plot_spectrogram from datasets.LJSpeech import LJSpeechDataset @@ -150,9 +150,9 @@ def main(args): current_step = num_iter + args.restore_step + epoch * len(dataloader) + 1 # setup lr - # current_lr = lr_decay(c.lr, current_step) - # for params_group in optimizer.param_groups: - # params_group['lr'] = current_lr + current_lr = lr_decay(c.lr, current_step, c.warmup_steps) + for params_group in optimizer.param_groups: + params_group['lr'] = current_lr optimizer.zero_grad() @@ -197,10 +197,13 @@ def main(args): + 0.5 * criterion(linear_output[:, :, :n_priority_freq], linear_spec_var[: ,: ,:n_priority_freq]) loss = mel_loss + linear_loss - # loss = loss.cuda() loss.backward() - grad_norm = nn.utils.clip_grad_norm(model.parameters(), 0.5) ## TODO: maybe no need + grad_norm, skip_flag = check_update(model, 0.5, 100) + if skip_flag: + optimizer.zero_grad() + print(" | > Iteration skipped!!") + continue optimizer.step() step_time = time.time() - start_time @@ -265,7 +268,6 @@ def main(args): best_loss, OUT_PATH, current_step, epoch) - #lr_scheduler.step(loss.data[0]) tb.add_scalar('Time/EpochTime', epoch_time, epoch) epoch_time = 0 diff --git a/utils/generic_utils.py b/utils/generic_utils.py index ed9661ff..4832ec44 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -95,8 +95,21 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path, return best_loss -def lr_decay(init_lr, global_step): - warmup_steps = 4000.0 +def check_update(model, grad_clip, grad_top): + r'''Check model gradient against unexpected jumps and failures''' + skip_flag = False + grad_norm = torch.nn.utils.clip_grad_norm(model.parameters(), grad_clip) + if np.isinf(grad_norm): + print(" | > Gradient is INF !!") + skip_flag = True + elif grad_norm > grad_top: + print(" | > Gradient is above the top limit !!") + skip_flag = True + return grad_norm, skip_flag + + +def lr_decay(init_lr, global_step, warmup_steps): + r'''from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py''' step = global_step + 1. 
lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5, step**-0.5) From 81669c1e58faa3e977bf1d7af43c879eb950b459 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Tue, 27 Feb 2018 07:32:09 -0800 Subject: [PATCH 06/13] more tests --- tests/generic_utils_text.py | 38 +++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 tests/generic_utils_text.py diff --git a/tests/generic_utils_text.py b/tests/generic_utils_text.py new file mode 100644 index 00000000..0461d263 --- /dev/null +++ b/tests/generic_utils_text.py @@ -0,0 +1,38 @@ +import unittest +import torch as T + +from TTS.utils.generic_utils import save_checkpoint, save_best_model +from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder + +OUT_PATH = '/tmp/test.pth.tar' + +class ModelSavingTests(unittest.TestCase): + + def save_checkpoint_test(self): + # create a dummy model + model = Prenet(128, out_features=[256, 128]) + model = T.nn.DataParallel(layer) + + # save the model + save_checkpoint(model, None, 100, + OUTPATH, 1, 1) + + # load the model to CPU + model_dict = torch.load(MODEL_PATH, map_location=lambda storage, + loc: storage) + model.load_state_dict(model_dict['model']) + + def save_best_model_test(self): + # create a dummy model + model = Prenet(256, out_features=[256, 256]) + model = T.nn.DataParallel(layer) + + # save the model + best_loss = save_best_model(model, None, 0, + 100, OUT_PATH, + 10, 1) + + # load the model to CPU + model_dict = torch.load(MODEL_PATH, map_location=lambda storage, + loc: storage) + model.load_state_dict(model_dict['model']) From 793563b586f44ec06ffbe1193d939e751c1e4737 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Fri, 2 Mar 2018 05:42:23 -0800 Subject: [PATCH 07/13] Remove pandas rom dataset --- config.json | 2 +- datasets/LJSpeech.py | 8 +++++--- train.py | 4 +++- utils/visual.py | 2 +- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/config.json b/config.json index 3505e114..d9ffc5c8 100644 --- a/config.json +++ b/config.json @@ -12,7 +12,7 @@ "text_cleaner": "english_cleaners", "epochs": 2000, - "lr": 0.005, + "lr": 0.0006, "warmup_steps": 4000, "batch_size": 180, "r": 5, diff --git a/datasets/LJSpeech.py b/datasets/LJSpeech.py index 3fe07078..81c2c9e9 100644 --- a/datasets/LJSpeech.py +++ b/datasets/LJSpeech.py @@ -1,4 +1,3 @@ -import pandas as pd import os import numpy as np import collections @@ -16,7 +15,10 @@ class LJSpeechDataset(Dataset): def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate, text_cleaner, num_mels, min_level_db, frame_shift_ms, frame_length_ms, preemphasis, ref_level_db, num_freq, power): - self.frames = pd.read_csv(csv_file, sep='|', header=None) + + f = open(csv_file, "r") + self.frames = [line.split('|') for line in f] + f.close() self.root_dir = root_dir self.outputs_per_step = outputs_per_step self.sample_rate = sample_rate @@ -40,7 +42,7 @@ class LJSpeechDataset(Dataset): def __getitem__(self, idx): wav_name = os.path.join(self.root_dir, self.frames.ix[idx, 0]) + '.wav' - text = self.frames.ix[idx, 1] + text = self.frames[idx][1] text = np.asarray(text_to_sequence(text, [self.cleaners]), dtype=np.int32) wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32) sample = {'text': text, 'wav': wav, 'item_idx': self.frames.ix[idx, 0]} diff --git a/train.py b/train.py index 9b8e8c9f..09c0e402 100644 --- a/train.py +++ b/train.py @@ -73,7 +73,8 @@ def main(args): dataloader = DataLoader(dataset, batch_size=c.batch_size, shuffle=True, collate_fn=dataset.collate_fn, - 
drop_last=True, num_workers=c.num_loader_workers) + drop_last=True, num_workers=c.num_loader_workers, + pin_memory=True) # setup the model model = Tacotron(c.embedding_size, @@ -108,6 +109,7 @@ def main(args): start_epoch = checkpoint['step'] // len(dataloader) best_loss = checkpoint['linear_loss'] start_epoch = 0 + args.restore_step = checkpoint['step'] else: print("\n > Starting a new training") diff --git a/utils/visual.py b/utils/visual.py index 935ed369..b0143fc9 100644 --- a/utils/visual.py +++ b/utils/visual.py @@ -5,7 +5,7 @@ import matplotlib.pyplot as plt def plot_alignment(alignment, info=None): - fig, ax = plt.subplots() + fig, ax = plt.subplots(figsize=(16,10)) im = ax.imshow(alignment.T, aspect='auto', origin='lower', interpolation='none') fig.colorbar(im, ax=ax) From 021ac3978d92657665453da811e7ceae68952024 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Fri, 2 Mar 2018 07:54:35 -0800 Subject: [PATCH 08/13] split train and validation steps --- config.json | 5 +- datasets/LJSpeech.py | 13 +- train.py | 483 ++++++++++++++++++++++++++----------------- 3 files changed, 297 insertions(+), 204 deletions(-) diff --git a/config.json b/config.json index d9ffc5c8..cd3aef72 100644 --- a/config.json +++ b/config.json @@ -20,11 +20,10 @@ "griffin_lim_iters": 60, "power": 1.5, - "num_loader_workers": 32, + "num_loader_workers": 16, "checkpoint": false, "save_step": 69, - "data_path": "/data/shared/KeithIto/LJSpeech-1.0", + "data_path": "/run/shm/erogol/LJSpeech-1.0", "output_path": "result", - "log_dir": "/home/erogol/projects/TTS/logs/" } diff --git a/datasets/LJSpeech.py b/datasets/LJSpeech.py index 81c2c9e9..ded16ed5 100644 --- a/datasets/LJSpeech.py +++ b/datasets/LJSpeech.py @@ -16,16 +16,15 @@ class LJSpeechDataset(Dataset): text_cleaner, num_mels, min_level_db, frame_shift_ms, frame_length_ms, preemphasis, ref_level_db, num_freq, power): - f = open(csv_file, "r") - self.frames = [line.split('|') for line in f] - f.close() + with open(csv_file, "r") as f: + self.frames = [line.split('|') for line in f] + self.frames = self.frames[:256] self.root_dir = root_dir self.outputs_per_step = outputs_per_step self.sample_rate = sample_rate self.cleaners = text_cleaner self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms, - frame_length_ms, preemphasis, ref_level_db, num_freq, power - ) + frame_length_ms, preemphasis, ref_level_db, num_freq, power) print(" > Reading LJSpeech from - {}".format(root_dir)) print(" | > Number of instances : {}".format(len(self.frames))) @@ -41,11 +40,11 @@ class LJSpeechDataset(Dataset): def __getitem__(self, idx): wav_name = os.path.join(self.root_dir, - self.frames.ix[idx, 0]) + '.wav' + self.frames[idx][0]) + '.wav' text = self.frames[idx][1] text = np.asarray(text_to_sequence(text, [self.cleaners]), dtype=np.int32) wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32) - sample = {'text': text, 'wav': wav, 'item_idx': self.frames.ix[idx, 0]} + sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]} return sample def get_dummy_data(self): diff --git a/train.py b/train.py index 09c0e402..c806f965 100644 --- a/train.py +++ b/train.py @@ -27,36 +27,265 @@ from utils.visual import plot_alignment, plot_spectrogram from datasets.LJSpeech import LJSpeechDataset from models.tacotron import Tacotron + use_cuda = torch.cuda.is_available() +parser = argparse.ArgumentParser() +parser.add_argument('--restore_step', type=int, + help='Global step to restore checkpoint', default=0) +parser.add_argument('--restore_path', type=str, + 
help='Folder path to checkpoints', default=0) +parser.add_argument('--config_path', type=str, + help='path to config file for training',) +args = parser.parse_args() + +# setup output paths and read configs +c = load_config(args.config_path) +_ = os.path.dirname(os.path.realpath(__file__)) +OUT_PATH = os.path.join(_, c.output_path) +OUT_PATH = create_experiment_folder(OUT_PATH) +CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints') +shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json')) + +# save config to tmp place to be loaded by subsequent modules. +file_name = str(os.getpid()) +tmp_path = os.path.join("/tmp/", file_name+'_tts') +pickle.dump(c, open(tmp_path, "wb")) + +# setup tensorboard +LOG_DIR = OUT_PATH +tb = SummaryWriter(LOG_DIR) + + +def signal_handler(signal, frame): + """Ctrl+C handler to remove empty experiment folder""" + print(" !! Pressed Ctrl+C !!") + remove_experiment_folder(OUT_PATH) + sys.exit(1) + + +def train(model, criterion, data_loader, optimizer, epoch): + model = model.train() + epoch_time = 0 + + print(" | > Epoch {}/{}".format(epoch, c.epochs)) + progbar = Progbar(len(data_loader.dataset) / c.batch_size) + n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq) + for num_iter, data in enumerate(data_loader): + start_time = time.time() + + # setup input data + text_input = data[0] + text_lengths = data[1] + linear_input = data[2] + mel_input = data[3] + + current_step = num_iter + args.restore_step + epoch * len(data_loader) + 1 + + # setup lr + current_lr = lr_decay(c.lr, current_step, c.warmup_steps) + for params_group in optimizer.param_groups: + params_group['lr'] = current_lr + + optimizer.zero_grad() + + # convert inputs to variables + text_input_var = Variable(text_input) + mel_spec_var = Variable(mel_input) + linear_spec_var = Variable(linear_input, volatile=True) + + # sort sequence by length for curriculum learning + # TODO: might be unnecessary + sorted_lengths, indices = torch.sort( + text_lengths.view(-1), dim=0, descending=True) + sorted_lengths = sorted_lengths.long().numpy() + text_input_var = text_input_var[indices] + mel_spec_var = mel_spec_var[indices] + linear_spec_var = linear_spec_var[indices] + + # dispatch data to GPU + if use_cuda: + text_input_var = text_input_var.cuda() + mel_spec_var = mel_spec_var.cuda() + linear_spec_var = linear_spec_var.cuda() + + # forward pass + mel_output, linear_output, alignments =\ + model.forward(text_input_var, mel_spec_var, + input_lengths= torch.autograd.Variable(torch.cuda.LongTensor(sorted_lengths))) + + # loss computation + mel_loss = criterion(mel_output, mel_spec_var) + linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \ + + 0.5 * criterion(linear_output[:, :, :n_priority_freq], + linear_spec_var[: ,: ,:n_priority_freq]) + loss = mel_loss + linear_loss + + # backpass and check the grad norm + loss.backward() + grad_norm, skip_flag = check_update(model, 0.5, 100) + if skip_flag: + optimizer.zero_grad() + print(" | > Iteration skipped!!") + continue + optimizer.step() + + step_time = time.time() - start_time + epoch_time += step_time + + # update + progbar.update(num_iter+1, values=[('total_loss', loss.data[0]), + ('linear_loss', linear_loss.data[0]), + ('mel_loss', mel_loss.data[0]), + ('grad_norm', grad_norm)]) + + # Plot Training Iter Stats + tb.add_scalar('TrainIterLoss/TotalLoss', loss.data[0], current_step) + tb.add_scalar('TrainIterLoss/LinearLoss', linear_loss.data[0], + current_step) + tb.add_scalar('TrainIterLoss/MelLoss', mel_loss.data[0], 
current_step) + tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'], + current_step) + tb.add_scalar('Params/GradNorm', grad_norm, current_step) + tb.add_scalar('Time/StepTime', step_time, current_step) + + if current_step % c.save_step == 0: + if c.checkpoint: + # save model + save_checkpoint(model, optimizer, linear_loss.data[0], + OUT_PATH, current_step, epoch) + + # Diagnostic visualizations + const_spec = linear_output[0].data.cpu().numpy() + gt_spec = linear_spec_var[0].data.cpu().numpy() + + const_spec = plot_spectrogram(const_spec, dataset.ap) + gt_spec = plot_spectrogram(gt_spec, dataset.ap) + tb.add_image('Visual/Reconstruction', const_spec, current_step) + tb.add_image('Visual/GroundTruth', gt_spec, current_step) + + align_img = alignments[0].data.cpu().numpy() + align_img = plot_alignment(align_img) + tb.add_image('Visual/Alignment', align_img, current_step) + + # Sample audio + audio_signal = linear_output[0].data.cpu().numpy() + dataset.ap.griffin_lim_iters = 60 + audio_signal = dataset.ap.inv_spectrogram(audio_signal.T) + try: + tb.add_audio('SampleAudio', audio_signal, current_step, + sample_rate=c.sample_rate) + except: + print("\n > Error at audio signal on TB!!") + print(audio_signal.max()) + print(audio_signal.min()) + + avg_linear_loss = np.mean( + progbar.sum_values['linear_loss'][0] / max(1, progbar.sum_values['linear_loss'][1])) + avg_mel_loss = np.mean( + progbar.sum_values['mel_loss'][0] / max(1, progbar.sum_values['mel_loss'][1])) + avg_total_loss = avg_mel_loss + avg_linear_loss + + # Plot Training Epoch Stats + tb.add_scalar('TrainEpochLoss/TotalLoss', loss.data[0], current_step) + tb.add_scalar('TrainEpochLoss/LinearLoss', linear_loss.data[0], current_step) + tb.add_scalar('TrainEpochLoss/MelLoss', mel_loss.data[0], current_step) + tb.add_scalar('Time/EpochTime', epoch_time, epoch) + epoch_time = 0 + + return avg_linear_loss, current_step + + +def evaluate(model, criterion, data_loader, current_step): + model = model.train() + epoch_time = 0 + + print("\n | > Validation") + n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq) + progbar = Progbar(len(data_loader.dataset) / c.batch_size) + + for num_iter, data in enumerate(data_loader): + start_time = time.time() + + # setup input data + text_input = data[0] + text_lengths = data[1] + linear_input = data[2] + mel_input = data[3] + + # convert inputs to variables + text_input_var = Variable(text_input) + mel_spec_var = Variable(mel_input) + linear_spec_var = Variable(linear_input, volatile=True) + + # dispatch data to GPU + if use_cuda: + text_input_var = text_input_var.cuda() + mel_spec_var = mel_spec_var.cuda() + linear_spec_var = linear_spec_var.cuda() + + # forward pass + mel_output, linear_output, alignments =\ + model.forward(text_input_var, mel_spec_var) + + # loss computation + mel_loss = criterion(mel_output, mel_spec_var) + linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \ + + 0.5 * criterion(linear_output[:, :, :n_priority_freq], + linear_spec_var[: ,: ,:n_priority_freq]) + loss = mel_loss + linear_loss + + step_time = time.time() - start_time + epoch_time += step_time + + # update + progbar.update(num_iter+1, values=[('total_loss', loss.data[0]), + ('linear_loss', linear_loss.data[0]), + ('mel_loss', mel_loss.data[0])]) + + # Diagnostic visualizations + idx = np.random.randint(c.batch_size) + const_spec = linear_output[idx].data.cpu().numpy() + gt_spec = linear_spec_var[idx].data.cpu().numpy() + + const_spec = plot_spectrogram(const_spec, 
data_loader.dataset.ap) + gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap) + tb.add_image('ValVisual/Reconstruction', const_spec, current_step) + tb.add_image('ValVisual/GroundTruth', gt_spec, current_step) + + align_img = alignments[idx].data.cpu().numpy() + align_img = plot_alignment(align_img) + tb.add_image('ValVisual/ValidationAlignment', align_img, current_step) + + # Sample audio + audio_signal = linear_output[idx].data.cpu().numpy() + data_loader.dataset.ap.griffin_lim_iters = 60 + audio_signal = data_loader.dataset.ap.inv_spectrogram(audio_signal.T) + try: + tb.add_audio('ValSampleAudio', audio_signal, current_step, + sample_rate=c.sample_rate) + except: + print("\n > Error at audio signal on TB!!") + print(audio_signal.max()) + print(audio_signal.min()) + + # compute average losses + avg_linear_loss = np.mean( + progbar.sum_values['linear_loss'][0] / max(1, progbar.sum_values['linear_loss'][1])) + avg_mel_loss = np.mean( + progbar.sum_values['mel_loss'][0] / max(1, progbar.sum_values['mel_loss'][1])) + avg_total_loss = avg_mel_loss + avg_linear_loss + + # Plot Learning Stats + tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss, current_step) + tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss, current_step) + tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step) + return avg_linear_loss + + def main(args): - # setup output paths and read configs - c = load_config(args.config_path) - _ = os.path.dirname(os.path.realpath(__file__)) - OUT_PATH = os.path.join(_, c.output_path) - OUT_PATH = create_experiment_folder(OUT_PATH) - CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints') - shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json')) - - # save config to tmp place to be loaded by subsequent modules. - file_name = str(os.getpid()) - tmp_path = os.path.join("/tmp/", file_name+'_tts') - pickle.dump(c, open(tmp_path, "wb")) - - # setup tensorboard - LOG_DIR = OUT_PATH - tb = SummaryWriter(LOG_DIR) - - # Ctrl+C handler to remove empty experiment folder - def signal_handler(signal, frame): - print(" !! 
Pressed Ctrl+C !!") - remove_experiment_folder(OUT_PATH) - sys.exit(1) - signal.signal(signal.SIGINT, signal_handler) - - # Setup the dataset - dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'), + train_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_train.csv'), os.path.join(c.data_path, 'wavs'), c.r, c.sample_rate, @@ -71,27 +300,42 @@ def main(args): c.power ) - dataloader = DataLoader(dataset, batch_size=c.batch_size, - shuffle=True, collate_fn=dataset.collate_fn, + train_loader = DataLoader(train_dataset, batch_size=c.batch_size, + shuffle=True, collate_fn=train_dataset.collate_fn, drop_last=True, num_workers=c.num_loader_workers, pin_memory=True) + + val_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_val.csv'), + os.path.join(c.data_path, 'wavs'), + c.r, + c.sample_rate, + c.text_cleaner, + c.num_mels, + c.min_level_db, + c.frame_shift_ms, + c.frame_length_ms, + c.preemphasis, + c.ref_level_db, + c.num_freq, + c.power + ) + + val_loader = DataLoader(val_dataset, batch_size=c.batch_size, + shuffle=True, collate_fn=val_dataset.collate_fn, + drop_last=True, num_workers= 4, + pin_memory=True) - # setup the model - model = Tacotron(c.embedding_size, c.hidden_size, c.num_mels, c.num_freq, c.r) - # plot model on tensorboard - dummy_input = dataset.get_dummy_data() - - ## TODO: onnx does not support RNN fully yet - # model_proto_path = os.path.join(OUT_PATH, "model.proto") - # onnx.export(model, dummy_input, model_proto_path, verbose=True) - # tb.add_graph_onnx(model_proto_path) - optimizer = optim.Adam(model.parameters(), lr=c.lr) + + if use_cuda: + criterion = nn.L1Loss().cuda() + else: + criterion = nn.L1Loss() if args.restore_step: checkpoint = torch.load(os.path.join( @@ -118,169 +362,20 @@ def main(args): num_params = count_parameters(model) print(" | > Model has {} parameters".format(num_params)) - - model = model.train() - + if not os.path.exists(CHECKPOINT_PATH): os.mkdir(CHECKPOINT_PATH) - - if use_cuda: - criterion = nn.L1Loss().cuda() - else: - criterion = nn.L1Loss() - - n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq) - - #lr_scheduler = ReduceLROnPlateau(optimizer, factor=c.lr_decay, - # patience=c.lr_patience, verbose=True) - epoch_time = 0 + if 'best_loss' not in locals(): best_loss = float('inf') + for epoch in range(0, c.epochs): - - print("\n | > Epoch {}/{}".format(epoch, c.epochs)) - progbar = Progbar(len(dataset) / c.batch_size) - - for num_iter, data in enumerate(dataloader): - start_time = time.time() - - text_input = data[0] - text_lengths = data[1] - linear_input = data[2] - mel_input = data[3] - - current_step = num_iter + args.restore_step + epoch * len(dataloader) + 1 - - # setup lr - current_lr = lr_decay(c.lr, current_step, c.warmup_steps) - for params_group in optimizer.param_groups: - params_group['lr'] = current_lr - - optimizer.zero_grad() - - # Add a single frame of zeros to Mel Specs for better end detection - #try: - # mel_input = np.concatenate((np.zeros( - # [c.batch_size, 1, c.num_mels], dtype=np.float32), - # mel_input[:, 1:, :]), axis=1) - #except: - # raise TypeError("not same dimension") - - # convert inputs to variables - text_input_var = Variable(text_input) - mel_spec_var = Variable(mel_input) - linear_spec_var = Variable(linear_input, volatile=True) - - # sort sequence by length. 
- # TODO: might be unnecessary - sorted_lengths, indices = torch.sort( - text_lengths.view(-1), dim=0, descending=True) - sorted_lengths = sorted_lengths.long().numpy() - - text_input_var = text_input_var[indices] - mel_spec_var = mel_spec_var[indices] - linear_spec_var = linear_spec_var[indices] - - if use_cuda: - text_input_var = text_input_var.cuda() - mel_spec_var = mel_spec_var.cuda() - linear_spec_var = linear_spec_var.cuda() - - mel_output, linear_output, alignments =\ - model.forward(text_input_var, mel_spec_var, - input_lengths= torch.autograd.Variable(torch.cuda.LongTensor(sorted_lengths))) - - mel_loss = criterion(mel_output, mel_spec_var) - #linear_loss = torch.abs(linear_output - linear_spec_var) - #linear_loss = 0.5 * \ - #torch.mean(linear_loss) + 0.5 * \ - #torch.mean(linear_loss[:, :n_priority_freq, :]) - linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \ - + 0.5 * criterion(linear_output[:, :, :n_priority_freq], - linear_spec_var[: ,: ,:n_priority_freq]) - loss = mel_loss + linear_loss - - loss.backward() - grad_norm, skip_flag = check_update(model, 0.5, 100) - if skip_flag: - optimizer.zero_grad() - print(" | > Iteration skipped!!") - continue - optimizer.step() - - step_time = time.time() - start_time - epoch_time += step_time - - progbar.update(num_iter+1, values=[('total_loss', loss.data[0]), - ('linear_loss', linear_loss.data[0]), - ('mel_loss', mel_loss.data[0]), - ('grad_norm', grad_norm)]) - - # Plot Learning Stats - tb.add_scalar('Loss/TotalLoss', loss.data[0], current_step) - tb.add_scalar('Loss/LinearLoss', linear_loss.data[0], - current_step) - tb.add_scalar('Loss/MelLoss', mel_loss.data[0], current_step) - tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'], - current_step) - tb.add_scalar('Params/GradNorm', grad_norm, current_step) - tb.add_scalar('Time/StepTime', step_time, current_step) - - align_img = alignments[0].data.cpu().numpy() - align_img = plot_alignment(align_img) - tb.add_image('Attn/Alignment', align_img, current_step) - - if current_step % c.save_step == 0: - - if c.checkpoint: - # save model - save_checkpoint(model, optimizer, linear_loss.data[0], - OUT_PATH, current_step, epoch) - - # Diagnostic visualizations - const_spec = linear_output[0].data.cpu().numpy() - gt_spec = linear_spec_var[0].data.cpu().numpy() - - const_spec = plot_spectrogram(const_spec, dataset.ap) - gt_spec = plot_spectrogram(gt_spec, dataset.ap) - tb.add_image('Spec/Reconstruction', const_spec, current_step) - tb.add_image('Spec/GroundTruth', gt_spec, current_step) - - align_img = alignments[0].data.cpu().numpy() - align_img = plot_alignment(align_img) - tb.add_image('Attn/Alignment', align_img, current_step) - - # Sample audio - audio_signal = linear_output[0].data.cpu().numpy() - dataset.ap.griffin_lim_iters = 60 - audio_signal = dataset.ap.inv_spectrogram(audio_signal.T) - try: - tb.add_audio('SampleAudio', audio_signal, current_step, - sample_rate=c.sample_rate) - except: - print("\n > Error at audio signal on TB!!") - print(audio_signal.max()) - print(audio_signal.min()) - - - # average loss after the epoch - avg_epoch_loss = np.mean( - progbar.sum_values['linear_loss'][0] / max(1, progbar.sum_values['linear_loss'][1])) - best_loss = save_best_model(model, optimizer, avg_epoch_loss, + train_loss, current_step = train(model, criterion, train_loader, optimizer, epoch) + val_loss = evaluate(model, criterion, val_loader, current_step) + best_loss = save_best_model(model, optimizer, val_loss, best_loss, OUT_PATH, current_step, epoch) - 
tb.add_scalar('Time/EpochTime', epoch_time, epoch) - epoch_time = 0 - - if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--restore_step', type=int, - help='Global step to restore checkpoint', default=0) - parser.add_argument('--restore_path', type=str, - help='Folder path to checkpoints', default=0) - parser.add_argument('--config_path', type=str, - help='path to config file for training',) - args = parser.parse_args() + signal.signal(signal.SIGINT, signal_handler) main(args) From 851718751131fb8b17906659cde3e3c2e8112fd8 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Fri, 2 Mar 2018 08:01:04 -0800 Subject: [PATCH 09/13] Run ready --- config.json | 2 +- datasets/LJSpeech.py | 1 - train.py | 2 ++ 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/config.json b/config.json index cd3aef72..0ad2921b 100644 --- a/config.json +++ b/config.json @@ -25,5 +25,5 @@ "checkpoint": false, "save_step": 69, "data_path": "/run/shm/erogol/LJSpeech-1.0", - "output_path": "result", + "output_path": "result" } diff --git a/datasets/LJSpeech.py b/datasets/LJSpeech.py index ded16ed5..334047a1 100644 --- a/datasets/LJSpeech.py +++ b/datasets/LJSpeech.py @@ -18,7 +18,6 @@ class LJSpeechDataset(Dataset): with open(csv_file, "r") as f: self.frames = [line.split('|') for line in f] - self.frames = self.frames[:256] self.root_dir = root_dir self.outputs_per_step = outputs_per_step self.sample_rate = sample_rate diff --git a/train.py b/train.py index c806f965..53c5698d 100644 --- a/train.py +++ b/train.py @@ -285,6 +285,7 @@ def evaluate(model, criterion, data_loader, current_step): def main(args): + # Setup the dataset train_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_train.csv'), os.path.join(c.data_path, 'wavs'), c.r, @@ -325,6 +326,7 @@ def main(args): drop_last=True, num_workers= 4, pin_memory=True) + model = Tacotron(c.embedding_size, c.hidden_size, c.num_mels, c.num_freq, From 3888b31b3c835a9a19656e7d452d003fec6d4488 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Mon, 5 Mar 2018 08:48:17 -0800 Subject: [PATCH 10/13] bug fix --- train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 53c5698d..fc6680a7 100644 --- a/train.py +++ b/train.py @@ -345,14 +345,14 @@ def main(args): model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) print("\n > Model restored from step %d\n" % args.restore_step) - start_epoch = checkpoint['step'] // len(dataloader) + start_epoch = checkpoint['step'] // len(train_loader) best_loss = checkpoint['linear_loss'] elif args.restore_path: checkpoint = torch.load(args.restore_path) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) print("\n > Model restored from step %d\n" % checkpoint['step']) - start_epoch = checkpoint['step'] // len(dataloader) + start_epoch = checkpoint['step'] // len(train_loader) best_loss = checkpoint['linear_loss'] start_epoch = 0 args.restore_step = checkpoint['step'] From a2a2065bb40ddb6c5ac625eb79f2bdd5a9539b2c Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Mon, 5 Mar 2018 08:54:23 -0800 Subject: [PATCH 11/13] bug fix --- train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index fc6680a7..69df193e 100644 --- a/train.py +++ b/train.py @@ -159,8 +159,8 @@ def train(model, criterion, data_loader, optimizer, epoch): const_spec = linear_output[0].data.cpu().numpy() gt_spec = linear_spec_var[0].data.cpu().numpy() - 
const_spec = plot_spectrogram(const_spec, dataset.ap) - gt_spec = plot_spectrogram(gt_spec, dataset.ap) + const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap) + gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap) tb.add_image('Visual/Reconstruction', const_spec, current_step) tb.add_image('Visual/GroundTruth', gt_spec, current_step) @@ -170,8 +170,8 @@ def train(model, criterion, data_loader, optimizer, epoch): # Sample audio audio_signal = linear_output[0].data.cpu().numpy() - dataset.ap.griffin_lim_iters = 60 - audio_signal = dataset.ap.inv_spectrogram(audio_signal.T) + data_loader.dataset.ap.griffin_lim_iters = 60 + audio_signal = data_loader.dataset.ap.inv_spectrogram(audio_signal.T) try: tb.add_audio('SampleAudio', audio_signal, current_step, sample_rate=c.sample_rate) From 405fbc434e3e8938a7d09d4046ce46c6d18082c4 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Tue, 6 Mar 2018 05:39:54 -0800 Subject: [PATCH 12/13] small clean --- train.py | 41 ++++++++++++++++++----------------------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/train.py b/train.py index 69df193e..6e7c7726 100644 --- a/train.py +++ b/train.py @@ -31,8 +31,6 @@ from models.tacotron import Tacotron use_cuda = torch.cuda.is_available() parser = argparse.ArgumentParser() -parser.add_argument('--restore_step', type=int, - help='Global step to restore checkpoint', default=0) parser.add_argument('--restore_path', type=str, help='Folder path to checkpoints', default=0) parser.add_argument('--config_path', type=str, @@ -67,6 +65,8 @@ def signal_handler(signal, frame): def train(model, criterion, data_loader, optimizer, epoch): model = model.train() epoch_time = 0 + avg_linear_loss = 0 + avg_mel_loss = 0 print(" | > Epoch {}/{}".format(epoch, c.epochs)) progbar = Progbar(len(data_loader.dataset) / c.batch_size) @@ -180,11 +180,10 @@ def train(model, criterion, data_loader, optimizer, epoch): print(audio_signal.max()) print(audio_signal.min()) - avg_linear_loss = np.mean( - progbar.sum_values['linear_loss'][0] / max(1, progbar.sum_values['linear_loss'][1])) - avg_mel_loss = np.mean( - progbar.sum_values['mel_loss'][0] / max(1, progbar.sum_values['mel_loss'][1])) - avg_total_loss = avg_mel_loss + avg_linear_loss + + avg_linear_loss /= (num_iter + 1) + avg_mel_loss /= (num_iter + 1) + avg_total_loss = avg_mel_loss + avg_linear_loss # Plot Training Epoch Stats tb.add_scalar('TrainEpochLoss/TotalLoss', loss.data[0], current_step) @@ -203,6 +202,9 @@ def evaluate(model, criterion, data_loader, current_step): print("\n | > Validation") n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq) progbar = Progbar(len(data_loader.dataset) / c.batch_size) + + avg_linear_loss = 0 + avg_mel_loss = 0 for num_iter, data in enumerate(data_loader): start_time = time.time() @@ -242,19 +244,22 @@ def evaluate(model, criterion, data_loader, current_step): progbar.update(num_iter+1, values=[('total_loss', loss.data[0]), ('linear_loss', linear_loss.data[0]), ('mel_loss', mel_loss.data[0])]) + + avg_linear_loss += linear_loss.data[0] + avg_mel_loss += avg_mel_loss.data[0] # Diagnostic visualizations idx = np.random.randint(c.batch_size) const_spec = linear_output[idx].data.cpu().numpy() gt_spec = linear_spec_var[idx].data.cpu().numpy() + align_img = alignments[idx].data.cpu().numpy() const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap) gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap) + align_img = plot_alignment(align_img) + tb.add_image('ValVisual/Reconstruction', const_spec, 
current_step) tb.add_image('ValVisual/GroundTruth', gt_spec, current_step) - - align_img = alignments[idx].data.cpu().numpy() - align_img = plot_alignment(align_img) tb.add_image('ValVisual/ValidationAlignment', align_img, current_step) # Sample audio @@ -270,10 +275,8 @@ def evaluate(model, criterion, data_loader, current_step): print(audio_signal.min()) # compute average losses - avg_linear_loss = np.mean( - progbar.sum_values['linear_loss'][0] / max(1, progbar.sum_values['linear_loss'][1])) - avg_mel_loss = np.mean( - progbar.sum_values['mel_loss'][0] / max(1, progbar.sum_values['mel_loss'][1])) + avg_linear_loss /= (num_iter + 1) + avg_mel_loss /= (num_iter + 1) avg_total_loss = avg_mel_loss + avg_linear_loss # Plot Learning Stats @@ -339,15 +342,7 @@ def main(args): else: criterion = nn.L1Loss() - if args.restore_step: - checkpoint = torch.load(os.path.join( - args.restore_path, 'checkpoint_%d.pth.tar' % args.restore_step)) - model.load_state_dict(checkpoint['model']) - optimizer.load_state_dict(checkpoint['optimizer']) - print("\n > Model restored from step %d\n" % args.restore_step) - start_epoch = checkpoint['step'] // len(train_loader) - best_loss = checkpoint['linear_loss'] - elif args.restore_path: + if args.restore_path: checkpoint = torch.load(args.restore_path) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) From b4032e8dffc2addf976468826129edea2b459ead Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Wed, 7 Mar 2018 06:58:51 -0800 Subject: [PATCH 13/13] best model ever changes --- config.json | 2 +- datasets/LJSpeech.py | 23 ++++++++++++++++ layers/attention.py | 65 +++++++++++++++++++++----------------------- layers/tacotron.py | 41 ++++++++-------------------- models/tacotron.py | 8 ++---- train.py | 20 ++++++++------ 6 files changed, 80 insertions(+), 79 deletions(-) diff --git a/config.json b/config.json index 0ad2921b..838f1510 100644 --- a/config.json +++ b/config.json @@ -14,7 +14,7 @@ "epochs": 2000, "lr": 0.0006, "warmup_steps": 4000, - "batch_size": 180, + "batch_size": 32, "r": 5, "griffin_lim_iters": 60, diff --git a/datasets/LJSpeech.py b/datasets/LJSpeech.py index 334047a1..a42a626e 100644 --- a/datasets/LJSpeech.py +++ b/datasets/LJSpeech.py @@ -26,6 +26,7 @@ class LJSpeechDataset(Dataset): frame_length_ms, preemphasis, ref_level_db, num_freq, power) print(" > Reading LJSpeech from - {}".format(root_dir)) print(" | > Number of instances : {}".format(len(self.frames))) + self._sort_frames() def load_wav(self, filename): try: @@ -34,6 +35,20 @@ class LJSpeechDataset(Dataset): except RuntimeError as e: print(" !! Cannot read file : {}".format(filename)) + def _sort_frames(self): + r"""Sort sequences in ascending order""" + lengths = np.array([len(ins[1]) for ins in self.frames]) + + print(" | > Max length sequence {}".format(np.max(lengths))) + print(" | > Min length sequence {}".format(np.min(lengths))) + print(" | > Avg length sequence {}".format(np.mean(lengths))) + + idxs = np.argsort(lengths) + new_frames = [None] * len(lengths) + for i, idx in enumerate(idxs): + new_frames[i] = self.frames[idx] + self.frames = new_frames + def __len__(self): return len(self.frames) @@ -47,9 +62,17 @@ class LJSpeechDataset(Dataset): return sample def get_dummy_data(self): + r"""Get a dummy input for testing""" return torch.autograd.Variable(torch.ones(16, 143)).type(torch.LongTensor) def collate_fn(self, batch): + r""" + Perform preprocessing and create a final data batch: + 1. 
PAD sequences with the longest sequence in the batch + 2. Convert Audio signal to Spectrograms. + 3. PAD sequences that can be divided by r. + 4. Convert Numpy to Torch tensors. + """ # Puts each data field into a tensor with outer dimension batch size if isinstance(batch[0], collections.Mapping): diff --git a/layers/attention.py b/layers/attention.py index 8d993cea..e7385149 100644 --- a/layers/attention.py +++ b/layers/attention.py @@ -5,26 +5,27 @@ from torch.nn import functional as F class BahdanauAttention(nn.Module): - def __init__(self, dim): + def __init__(self, annot_dim, query_dim, hidden_dim): super(BahdanauAttention, self).__init__() - self.query_layer = nn.Linear(dim, dim, bias=False) - self.tanh = nn.Tanh() - self.v = nn.Linear(dim, 1, bias=False) + self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True) + self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True) + self.v = nn.Linear(hidden_dim, 1, bias=False) - def forward(self, query, processed_inputs): + def forward(self, annots, query): """ - Args: - query: (batch, 1, dim) or (batch, dim) - processed_inputs: (batch, max_time, dim) + Shapes: + - query: (batch, 1, dim) or (batch, dim) + - annots: (batch, max_time, dim) """ if query.dim() == 2: # insert time-axis for broadcasting query = query.unsqueeze(1) # (batch, 1, dim) processed_query = self.query_layer(query) + processed_annots = self.annot_layer(annots) # (batch, max_time, 1) - alignment = self.v(self.tanh(processed_query + processed_inputs)) + alignment = self.v(nn.functional.tanh(processed_query + processed_annots)) # (batch, max_time) return alignment.squeeze(-1) @@ -34,7 +35,7 @@ def get_mask_from_lengths(inputs, inputs_lengths): """Get mask tensor from list of length Args: - inputs: (batch, max_time, dim) + inputs: Tensor in size (batch, max_time, dim) inputs_lengths: array like """ mask = inputs.data.new(inputs.size(0), inputs.size(1)).byte().zero_() @@ -43,52 +44,48 @@ def get_mask_from_lengths(inputs, inputs_lengths): return ~mask -class AttentionWrapper(nn.Module): - def __init__(self, rnn_cell, alignment_model, +class AttentionRNN(nn.Module): + def __init__(self, out_dim, annot_dim, memory_dim, score_mask_value=-float("inf")): - super(AttentionWrapper, self).__init__() - self.rnn_cell = rnn_cell - self.alignment_model = alignment_model + super(AttentionRNN, self).__init__() + self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, out_dim) + self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim) self.score_mask_value = score_mask_value - def forward(self, query, context_vec, cell_state, inputs, - processed_inputs=None, mask=None, inputs_lengths=None): + def forward(self, memory, context, rnn_state, annotations, + mask=None, annotations_lengths=None): - if processed_inputs is None: - processed_inputs = inputs - - if inputs_lengths is not None and mask is None: - mask = get_mask_from_lengths(inputs, inputs_lengths) + if annotations_lengths is not None and mask is None: + mask = get_mask_from_lengths(annotations, annotations_lengths) # Alignment # (batch, max_time) # e_{ij} = a(s_{i-1}, h_j) - # import ipdb - # ipdb.set_trace() - alignment = self.alignment_model(cell_state, processed_inputs) + alignment = self.alignment_model(annotations, rnn_state) + # TODO: needs recheck. 
if mask is not None: mask = mask.view(query.size(0), -1) alignment.data.masked_fill_(mask, self.score_mask_value) - # Normalize context_vec weight + # Normalize context weight alignment = F.softmax(alignment, dim=-1) # Attention context vector # (batch, 1, dim) # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j - context_vec = torch.bmm(alignment.unsqueeze(1), inputs) - context_vec = context_vec.squeeze(1) + context = torch.bmm(alignment.unsqueeze(1), annotations) + context = context.squeeze(1) - # Concat input query and previous context_vec context - cell_input = torch.cat((query, context_vec), -1) - #cell_input = cell_input.unsqueeze(1) + # Concat input query and previous context context + rnn_input = torch.cat((memory, context), -1) + #rnn_input = rnn_input.unsqueeze(1) # Feed it to RNN # s_i = f(y_{i-1}, c_{i}, s_{i-1}) - cell_output = self.rnn_cell(cell_input, cell_state) + rnn_output = self.rnn_cell(rnn_input, rnn_state) - context_vec = context_vec.squeeze(1) - return cell_output, context_vec, alignment + context = context.squeeze(1) + return rnn_output, context, alignment diff --git a/layers/tacotron.py b/layers/tacotron.py index ac348017..6f5926a8 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -3,7 +3,7 @@ import torch from torch.autograd import Variable from torch import nn -from .attention import BahdanauAttention, AttentionWrapper +from .attention import AttentionRNN from .attention import get_mask_from_lengths class Prenet(nn.Module): @@ -219,15 +219,10 @@ class Decoder(nn.Module): self.memory_dim = memory_dim self.eps = eps self.r = r - # input -> |Linear| -> processed_inputs - self.input_layer = nn.Linear(in_features, 256, bias=False) # memory -> |Prenet| -> processed_memory self.prenet = Prenet(memory_dim * r, out_features=[256, 128]) # processed_inputs, processed_memory -> |Attention| -> Attention, Alignment, RNN_State - self.attention_rnn = AttentionWrapper( - nn.GRUCell(in_features + 128, 256), - BahdanauAttention(256) - ) + self.attention_rnn = AttentionRNN(256, in_features, 128) # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input self.project_to_decoder_in = nn.Linear(256+in_features, 256) # decoder_RNN_input -> |RNN| -> RNN_state @@ -245,9 +240,9 @@ class Decoder(nn.Module): Args: inputs: Encoder outputs. - memory: Decoder memory (autoregression. If None (at eval-time), + memory (None): Decoder memory (autoregression. If None (at eval-time), decoder outputs are used as decoder inputs. - input_lengths: Encoder output (memory) lengths. If not None, used for + input_lengths (None): input lengths, used for attention masking. Shapes: @@ -256,12 +251,11 @@ class Decoder(nn.Module): """ B = inputs.size(0) - # TODO: take this segment into Attention module. - processed_inputs = self.input_layer(inputs) - if input_lengths is not None: - mask = get_mask_from_lengths(processed_inputs, input_lengths) - else: - mask = None + + # if input_lengths is not None: + # mask = get_mask_from_lengths(processed_inputs, input_lengths) + # else: + # mask = None # Run greedy decoding if memory is None greedy = memory is None @@ -300,20 +294,7 @@ class Decoder(nn.Module): memory_input = initial_memory while True: if t > 0: - # using harmonized teacher-forcing. - # from https://arxiv.org/abs/1707.06588 - if greedy: - memory_input = outputs[-1] - else: - # combine prev. model output and prev. 
real target - memory_input = torch.div(outputs[-1] + memory[t-1], 2.0) - memory_input = torch.nn.functional.dropout(memory_input, - 0.1, - training=True) - # add a random noise - noise = torch.autograd.Variable( - memory_input.data.new(memory_input.size()).normal_(0.0, 1.0)) - memory_input = memory_input + noise + memory_input = outputs[-1] if greedy else memory[t - 1] # Prenet processed_memory = self.prenet(memory_input) @@ -321,7 +302,7 @@ class Decoder(nn.Module): # Attention RNN attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn( processed_memory, current_context_vec, attention_rnn_hidden, - inputs, processed_inputs=processed_inputs, mask=mask) + inputs) # Concat RNN output and attention context vector decoder_input = self.project_to_decoder_in( diff --git a/models/tacotron.py b/models/tacotron.py index c6218e40..57c9b43d 100644 --- a/models/tacotron.py +++ b/models/tacotron.py @@ -9,11 +9,11 @@ from TTS.layers.tacotron import Prenet, Encoder, Decoder, CBHG class Tacotron(nn.Module): def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80, freq_dim=1025, r=5, padding_idx=None, - use_memory_mask=False): + use_atten_mask=False): super(Tacotron, self).__init__() self.mel_dim = mel_dim self.linear_dim = linear_dim - self.use_memory_mask = use_memory_mask + self.use_atten_mask = use_atten_mask self.embedding = nn.Embedding(len(symbols), embedding_dim, padding_idx=padding_idx) print(" | > Embedding dim : {}".format(len(symbols))) @@ -33,9 +33,7 @@ class Tacotron(nn.Module): # (B, T', in_dim) encoder_outputs = self.encoder(inputs) - if self.use_memory_mask: - input_lengths = input_lengths - else: + if not self.use_atten_mask: input_lengths = None # (B, T', mel_dim*r) diff --git a/train.py b/train.py index 6e7c7726..8aa6567d 100644 --- a/train.py +++ b/train.py @@ -199,7 +199,7 @@ def evaluate(model, criterion, data_loader, current_step): model = model.train() epoch_time = 0 - print("\n | > Validation") + print(" | > Validation") n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq) progbar = Progbar(len(data_loader.dataset) / c.batch_size) @@ -246,10 +246,10 @@ def evaluate(model, criterion, data_loader, current_step): ('mel_loss', mel_loss.data[0])]) avg_linear_loss += linear_loss.data[0] - avg_mel_loss += avg_mel_loss.data[0] + avg_mel_loss += mel_loss.data[0] # Diagnostic visualizations - idx = np.random.randint(c.batch_size) + idx = np.random.randint(mel_input.shape[0]) const_spec = linear_output[idx].data.cpu().numpy() gt_spec = linear_spec_var[idx].data.cpu().numpy() align_img = alignments[idx].data.cpu().numpy() @@ -270,7 +270,7 @@ def evaluate(model, criterion, data_loader, current_step): tb.add_audio('ValSampleAudio', audio_signal, current_step, sample_rate=c.sample_rate) except: - print("\n > Error at audio signal on TB!!") + print(" | > Error at audio signal on TB!!") print(audio_signal.max()) print(audio_signal.min()) @@ -305,8 +305,8 @@ def main(args): ) train_loader = DataLoader(train_dataset, batch_size=c.batch_size, - shuffle=True, collate_fn=train_dataset.collate_fn, - drop_last=True, num_workers=c.num_loader_workers, + shuffle=False, collate_fn=train_dataset.collate_fn, + drop_last=False, num_workers=c.num_loader_workers, pin_memory=True) val_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_val.csv'), @@ -325,15 +325,16 @@ def main(args): ) val_loader = DataLoader(val_dataset, batch_size=c.batch_size, - shuffle=True, collate_fn=val_dataset.collate_fn, - drop_last=True, num_workers= 4, + shuffle=False, 
collate_fn=val_dataset.collate_fn, + drop_last=False, num_workers= 4, pin_memory=True) model = Tacotron(c.embedding_size, c.hidden_size, c.num_mels, c.num_freq, - c.r) + c.r, + use_atten_mask=True) optimizer = optim.Adam(model.parameters(), lr=c.lr) @@ -352,6 +353,7 @@ def main(args): start_epoch = 0 args.restore_step = checkpoint['step'] else: + args.restore_step = 0 print("\n > Starting a new training") if use_cuda:
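
For reference, the snippet below is a minimal standalone sketch of the warmup learning-rate schedule that `lr_decay` implements from patch 05 onward. The formula is copied from the `utils/generic_utils.py` hunk above; the default `warmup_steps=4000` and the example `init_lr=0.0006` mirror the values in `config.json`, and the small demo loop is illustrative only, not part of the repository.

```python
import numpy as np


def lr_decay(init_lr, global_step, warmup_steps=4000.0):
    """Warmup-then-decay schedule used by train.py after patch 05.

    Roughly linear warmup for `warmup_steps` updates, then
    inverse-square-root decay (the schedule referenced from
    r9y9/tacotron_pytorch in the utils/generic_utils.py docstring).
    """
    step = global_step + 1.0
    return init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5,
                                                    step**-0.5)


if __name__ == "__main__":
    # Illustrative values only: 0.0006 and 4000 mirror config.json above.
    for step in (0, 500, 4000, 20000, 100000):
        print("step %6d -> lr %.6f" % (step, lr_decay(0.0006, step)))
```

The learning rate peaks near `init_lr` around `global_step == warmup_steps` and decays as `step**-0.5` afterwards, which is why the fixed `lr_decay(c.lr, current_step)` call of patch 01 is replaced by the three-argument form once `warmup_steps` is added to the config.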