diff --git a/datasets/LJSpeech.py b/datasets/LJSpeech.py
index f95e4d92..b5e0d092 100644
--- a/datasets/LJSpeech.py
+++ b/datasets/LJSpeech.py
@@ -3,6 +3,7 @@
 import os
 import numpy as np
 import collections
 import librosa
+import torch
 from torch.utils.data import Dataset
 from TTS.utils.text import text_to_sequence
@@ -45,6 +46,9 @@ class LJSpeechDataset(Dataset):
         sample = {'text': text, 'wav': wav}
         return sample
 
+    def get_dummy_data(self):
+        return torch.autograd.Variable(torch.ones(16, 143)).type(torch.LongTensor)
+
     def collate_fn(self, batch):
 
         # Puts each data field into a tensor with outer dimension batch size
@@ -73,7 +77,7 @@ class LJSpeechDataset(Dataset):
             magnitude = magnitude.transpose(0, 2, 1)
             mel = mel.transpose(0, 2, 1)
 
-            return text, magnitude, mel
+            return text, text_lenghts, magnitude, mel
 
         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
                          found {}"
diff --git a/train.py b/train.py
index 89e2f825..ed44bb4f 100644
--- a/train.py
+++ b/train.py
@@ -12,6 +12,7 @@
 import numpy as np
 import torch.nn as nn
 from torch import optim
+from torch import onnx
 from torch.autograd import Variable
 from torch.utils.data import DataLoader
 from torch.optim.lr_scheduler import ReduceLROnPlateau
@@ -52,6 +53,7 @@ def main(args):
         sys.exit(1)
     signal.signal(signal.SIGINT, signal_handler)
 
+    # Setup the dataset
     dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
                               os.path.join(c.data_path, 'wavs'),
                               c.r,
@@ -67,11 +69,25 @@
                               c.power
                               )
 
+    dataloader = DataLoader(dataset, batch_size=c.batch_size,
+                            shuffle=True, collate_fn=dataset.collate_fn,
+                            drop_last=True, num_workers=c.num_loader_workers)
+
+    # Setup the model
     model = Tacotron(c.embedding_size,
                      c.hidden_size,
                      c.num_mels,
                      c.num_freq,
                      c.r)
+
+    # Plot model on tensorboard
+    dummy_input = dataset.get_dummy_data()
+
+    ## TODO: onnx does not support RNN fully yet
+    # model_proto_path = os.path.join(OUT_PATH, "model.proto")
+    # onnx.export(model, dummy_input, model_proto_path, verbose=True)
+    # tb.add_graph_onnx(model_proto_path)
+
     if use_cuda:
         model = nn.DataParallel(model.cuda())
@@ -105,9 +121,6 @@ def main(args):
     epoch_time = 0
     for epoch in range(c.epochs):
-        dataloader = DataLoader(dataset, batch_size=c.batch_size,
-                                shuffle=True, collate_fn=dataset.collate_fn,
-                                drop_last=True, num_workers=c.num_loader_workers)
 
         print("\n | > Epoch {}/{}".format(epoch, c.epochs))
         progbar = Progbar(len(dataset) / c.batch_size)
@@ -115,8 +128,9 @@
             start_time = time.time()
 
             text_input = data[0]
-            magnitude_input = data[1]
-            mel_input = data[2]
+            text_lengths = data[1]
+            magnitude_input = data[2]
+            mel_input = data[3]
 
             current_step = i + args.restore_step + epoch * len(dataloader) + 1
@@ -137,6 +151,8 @@ def main(args):
             if use_cuda:
                 text_input_var = Variable(torch.from_numpy(text_input).type(
                     torch.cuda.LongTensor)).cuda()
+                text_lengths_var = Variable(torch.from_numpy(text_lengths).type(
+                    torch.cuda.LongTensor)).cuda()
                 mel_input_var = Variable(torch.from_numpy(mel_input).type(
                     torch.cuda.FloatTensor)).cuda()
                 mel_spec_var = Variable(torch.from_numpy(mel_input).type(
                     torch.cuda.FloatTensor)).cuda()
@@ -147,6 +163,8 @@ def main(args):
             else:
                 text_input_var = Variable(torch.from_numpy(text_input).type(
                     torch.LongTensor),)
+                text_lengths_var = Variable(torch.from_numpy(text_lengths).type(
+                    torch.LongTensor))
                 mel_input_var = Variable(torch.from_numpy(mel_input).type(
                     torch.FloatTensor))
                 mel_spec_var = Variable(torch.from_numpy(
@@ -155,7 +173,7 @@
                     magnitude_input).type(torch.FloatTensor))
 
             mel_output, linear_output, alignments =\
-                model.forward(text_input_var, mel_input_var)
+                model.forward(text_input_var, mel_input_var, input_lengths=text_lengths_var)
 
             mel_loss = criterion(mel_output, mel_spec_var)
             #linear_loss = torch.abs(linear_output - linear_spec_var)
@@ -169,7 +187,7 @@ def main(args):
             # loss = loss.cuda()
             loss.backward()
-            grad_norm = nn.utils.clip_grad_norm(model.parameters(), 1.)
+            grad_norm = nn.utils.clip_grad_norm(model.parameters(), 1.)  ## TODO: maybe not needed
             optimizer.step()
 
             step_time = time.time() - start_time
@@ -180,11 +198,12 @@ def main(args):
                                       ('mel_loss', mel_loss.data[0]),
                                       ('grad_norm', grad_norm)])
+
+            # Plot Learning Stats
             tb.add_scalar('Loss/TotalLoss', loss.data[0], current_step)
             tb.add_scalar('Loss/LinearLoss', linear_loss.data[0], current_step)
             tb.add_scalar('Loss/MelLoss', mel_loss.data[0], current_step)
-            tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'], current_step)
             tb.add_scalar('Params/GradNorm', grad_norm, current_step)