mirror of https://github.com/coqui-ai/TTS.git
Mask inputs by length to reduce the effect on the attention module
parent 4e0ab65bbf
commit 966d567540
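The change threads per-utterance text lengths from the dataset's `collate_fn` through the training loop so the model can down-weight padded time steps before they reach the attention module. The masking code itself is not part of this diff; the snippet below is only a minimal sketch of the idea, using a hypothetical `sequence_mask` helper that turns a batch of lengths into a boolean padding mask.

```python
import torch

def sequence_mask(lengths, max_len=None):
    """Hypothetical helper: mask[b, t] is True for real tokens, False for padding."""
    if max_len is None:
        max_len = int(lengths.max())
    positions = torch.arange(max_len, device=lengths.device)   # (T,)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)       # (B, T)

# Example: a batch of three utterances padded to length 5.
lengths = torch.tensor([5, 3, 2])
mask = sequence_mask(lengths)
# tensor([[ True,  True,  True,  True,  True],
#         [ True,  True,  True, False, False],
#         [ True,  True, False, False, False]])
```

This is why the dataset's `collate_fn` now also returns the text lengths and why `model.forward` gains an `input_lengths` argument in the hunks below.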
@@ -3,6 +3,7 @@ import os
 import numpy as np
 import collections
 import librosa
+import torch
 from torch.utils.data import Dataset
 
 from TTS.utils.text import text_to_sequence
@@ -45,6 +46,9 @@ class LJSpeechDataset(Dataset):
         sample = {'text': text, 'wav': wav}
         return sample
 
+    def get_dummy_data(self):
+        return torch.autograd.Variable(torch.ones(16, 143)).type(torch.LongTensor)
+
     def collate_fn(self, batch):
 
         # Puts each data field into a tensor with outer dimension batch size
@@ -73,7 +77,7 @@ class LJSpeechDataset(Dataset):
             magnitude = magnitude.transpose(0, 2, 1)
             mel = mel.transpose(0, 2, 1)
 
-            return text, magnitude, mel
+            return text, text_lenghts, magnitude, mel
 
         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
                          found {}"
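This hunk shows only the tail of `collate_fn`; the code that computes `text_lenghts` is unchanged and therefore not shown. For orientation, a padding collate along these lines records each item's length before zero-padding, roughly as in the sketch below (names and shapes are illustrative, not the file's actual code; spectrogram handling is omitted).

```python
import numpy as np

def collate_text(batch):
    # batch: list of dicts such as {'text': np.ndarray of token ids, 'wav': np.ndarray}
    text = [d['text'] for d in batch]
    text_lenghts = np.array([len(t) for t in text])  # lengths before padding
    max_len = text_lenghts.max()
    # zero-pad every sequence to the longest one in the batch
    text = np.stack([np.pad(t, (0, max_len - len(t)), mode='constant') for t in text])
    return text, text_lenghts
```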
train.py (35 changed lines)
@@ -12,6 +12,7 @@ import numpy as np
 
 import torch.nn as nn
 from torch import optim
+from torch import onnx
 from torch.autograd import Variable
 from torch.utils.data import DataLoader
 from torch.optim.lr_scheduler import ReduceLROnPlateau
@@ -52,6 +53,7 @@ def main(args):
         sys.exit(1)
     signal.signal(signal.SIGINT, signal_handler)
 
+    # Setup the dataset
     dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
                               os.path.join(c.data_path, 'wavs'),
                               c.r,
@@ -67,11 +69,25 @@ def main(args):
                               c.power
                              )
 
+    dataloader = DataLoader(dataset, batch_size=c.batch_size,
+                            shuffle=True, collate_fn=dataset.collate_fn,
+                            drop_last=True, num_workers=c.num_loader_workers)
+
+    # setup the model
     model = Tacotron(c.embedding_size,
                      c.hidden_size,
                      c.num_mels,
                      c.num_freq,
                      c.r)
 
+    # plot model on tensorboard
+    dummy_input = dataset.get_dummy_data()
+
+    ## TODO: onnx does not support RNN fully yet
+    # model_proto_path = os.path.join(OUT_PATH, "model.proto")
+    # onnx.export(model, dummy_input, model_proto_path, verbose=True)
+    # tb.add_graph_onnx(model_proto_path)
+
     if use_cuda:
         model = nn.DataParallel(model.cuda())
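With the `DataLoader` now built once here (it is removed from the per-epoch loop in the next hunk), every batch it yields is the four-item tuple produced by `collate_fn`. A minimal consumption sketch, using the config object `c` as above:

```python
# Sketch only: iterate the loader each epoch and unpack the collate_fn output.
for epoch in range(c.epochs):
    for i, data in enumerate(dataloader):
        text_input = data[0]       # padded token ids, one row per utterance
        text_lengths = data[1]     # original lengths before padding
        magnitude_input = data[2]  # linear spectrogram batch
        mel_input = data[3]        # mel spectrogram batch
```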
@@ -105,9 +121,6 @@ def main(args):
     epoch_time = 0
     for epoch in range(c.epochs):
 
-        dataloader = DataLoader(dataset, batch_size=c.batch_size,
-                                shuffle=True, collate_fn=dataset.collate_fn,
-                                drop_last=True, num_workers=c.num_loader_workers)
         print("\n | > Epoch {}/{}".format(epoch, c.epochs))
         progbar = Progbar(len(dataset) / c.batch_size)
 
@@ -115,8 +128,9 @@ def main(args):
             start_time = time.time()
 
             text_input = data[0]
-            magnitude_input = data[1]
-            mel_input = data[2]
+            text_lengths = data[1]
+            magnitude_input = data[2]
+            mel_input = data[3]
 
             current_step = i + args.restore_step + epoch * len(dataloader) + 1
 
@@ -137,6 +151,8 @@ def main(args):
             if use_cuda:
                 text_input_var = Variable(torch.from_numpy(text_input).type(
                     torch.cuda.LongTensor)).cuda()
+                text_lengths_var = Variable(torch.from_numpy(text_lengths).type(
+                    torch.cuda.LongTensor)).cuda()
                 mel_input_var = Variable(torch.from_numpy(mel_input).type(
                     torch.cuda.FloatTensor)).cuda()
                 mel_spec_var = Variable(torch.from_numpy(mel_input).type(
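This CUDA branch and the CPU branch in the next hunk differ only in tensor type and the `.cuda()` call; a small hypothetical helper could fold them together (a refactoring sketch, not code from this commit):

```python
import torch
from torch.autograd import Variable  # mirrors the import already used in train.py

def to_var(array, tensor_type, use_cuda):
    """Wrap a numpy array in a Variable of the given type, optionally on the GPU."""
    var = Variable(torch.from_numpy(array).type(tensor_type))
    return var.cuda() if use_cuda else var

# e.g. text_lengths_var = to_var(text_lengths, torch.LongTensor, use_cuda)
```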
@@ -147,6 +163,8 @@ def main(args):
             else:
                 text_input_var = Variable(torch.from_numpy(text_input).type(
                     torch.LongTensor),)
+                text_lengths_var = Variable(torch.from_numpy(text_lengths).type(
+                    torch.LongTensor))
                 mel_input_var = Variable(torch.from_numpy(mel_input).type(
                     torch.FloatTensor))
                 mel_spec_var = Variable(torch.from_numpy(
@@ -155,7 +173,7 @@ def main(args):
                     magnitude_input).type(torch.FloatTensor))
 
             mel_output, linear_output, alignments =\
-                model.forward(text_input_var, mel_input_var)
+                model.forward(text_input_var, mel_input_var, input_lengths=text_lengths_var)
 
             mel_loss = criterion(mel_output, mel_spec_var)
             #linear_loss = torch.abs(linear_output - linear_spec_var)
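`input_lengths` is now passed to `model.forward`, but the attention-side masking itself is not part of this diff. One common way to use such lengths, shown purely as an illustrative sketch rather than the Tacotron module's actual code, is to mask the alignment energies over padded encoder steps before the softmax so that padding receives effectively zero attention weight:

```python
import torch
import torch.nn.functional as F

def masked_alignment(energies, input_lengths):
    """energies: (B, T_in) raw attention scores; input_lengths: (B,) valid lengths."""
    max_len = energies.size(1)
    positions = torch.arange(max_len, device=energies.device).unsqueeze(0)  # (1, T_in)
    pad_mask = positions >= input_lengths.unsqueeze(1)                      # True on padding
    energies = energies.masked_fill(pad_mask, -1e9)  # padded steps contribute ~0 after softmax
    return F.softmax(energies, dim=1)
```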
@@ -169,7 +187,7 @@ def main(args):
             # loss = loss.cuda()
 
             loss.backward()
-            grad_norm = nn.utils.clip_grad_norm(model.parameters(), 1.)
+            grad_norm = nn.utils.clip_grad_norm(model.parameters(), 1.)  ## TODO: maybe no need
             optimizer.step()
 
             step_time = time.time() - start_time
@@ -180,11 +198,12 @@ def main(args):
                                ('mel_loss', mel_loss.data[0]),
                                ('grad_norm', grad_norm)])
 
+            # Plot Learning Stats
             tb.add_scalar('Loss/TotalLoss', loss.data[0], current_step)
             tb.add_scalar('Loss/LinearLoss', linear_loss.data[0],
                           current_step)
             tb.add_scalar('Loss/MelLoss', mel_loss.data[0], current_step)
 
             tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
                           current_step)
             tb.add_scalar('Params/GradNorm', grad_norm, current_step)