mirror of https://github.com/coqui-ai/TTS.git

Merge branch 'normal-attention+masked-loss'

commit b1ade13ff4

config.json | 13
@@ -7,25 +7,24 @@
     "preemphasis": 0.97,
     "min_level_db": -100,
     "ref_level_db": 20,
-    "hidden_size": 128,
     "embedding_size": 256,
     "text_cleaner": "english_cleaners",
 
     "epochs": 2000,
     "lr": 0.001,
     "warmup_steps": 4000,
-    "batch_size": 32,
-    "eval_batch_size": 32,
+    "batch_size": 128,
+    "eval_batch_size":32,
     "r": 5,
 
     "griffin_lim_iters": 60,
     "power": 1.5,
 
-    "num_loader_workers": 12,
+    "num_loader_workers": 8,
 
-    "checkpoint": false,
-    "save_step": 69,
+    "checkpoint": true,
+    "save_step": 378,
     "data_path": "/run/shm/erogol/LJSpeech-1.0",
     "min_seq_len": 0,
-    "output_path": "result"
+    "output_path": "/data/shared/erogol_models/"
 }
datasets/LJSpeech.py
@@ -7,7 +7,8 @@ from torch.utils.data import Dataset
 
 from TTS.utils.text import text_to_sequence
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.data import prepare_data, pad_data, pad_per_step
+from TTS.utils.data import (prepare_data, pad_per_step,
+                            prepare_tensor, prepare_stop_target)
 
 
 class LJSpeechDataset(Dataset):
@@ -93,26 +94,27 @@ class LJSpeechDataset(Dataset):
             text_lenghts = np.array([len(x) for x in text])
             max_text_len = np.max(text_lenghts)
 
+            linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
+            mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
+            mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame
+
+            # compute 'stop token' targets
+            stop_targets = [np.array([0.]*(mel_len-1)) for mel_len in mel_lengths]
+
+            # PAD stop targets
+            stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)
+
             # PAD sequences with largest length of the batch
             text = prepare_data(text).astype(np.int32)
             wav = prepare_data(wav)
 
-            linear = np.array([self.ap.spectrogram(w).astype('float32') for w in wav])
-            mel = np.array([self.ap.melspectrogram(w).astype('float32') for w in wav])
+            # PAD features with largest length + a zero frame
+            linear = prepare_tensor(linear, self.outputs_per_step)
+            mel = prepare_tensor(mel, self.outputs_per_step)
             assert mel.shape[2] == linear.shape[2]
             timesteps = mel.shape[2]
 
-            # PAD with zeros that can be divided by outputs per step
-            if (timesteps + 1) % self.outputs_per_step != 0:
-                pad_len = self.outputs_per_step - \
-                    ((timesteps + 1) % self.outputs_per_step)
-                pad_len += 1
-            else:
-                pad_len = 1
-            linear = pad_per_step(linear, pad_len)
-            mel = pad_per_step(mel, pad_len)
-
-            # reshape jombo
+            # B x T x D
             linear = linear.transpose(0, 2, 1)
             mel = mel.transpose(0, 2, 1)
 
@@ -121,7 +123,10 @@ class LJSpeechDataset(Dataset):
             text = torch.LongTensor(text)
             linear = torch.FloatTensor(linear)
             mel = torch.FloatTensor(mel)
-            return text, text_lenghts, linear, mel, item_idxs[0]
+            mel_lengths = torch.LongTensor(mel_lengths)
+            stop_targets = torch.FloatTensor(stop_targets)
+
+            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]
 
         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
                          found {}"
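
The new collate logic pairs every decoder frame with a stop-token target: zeros over real frames, with the padding itself (value 1) marking the stop. A minimal numpy sketch of that behavior, using made-up lengths and r rather than the repo's actual batch:

    import numpy as np

    mel_lengths = [5, 3]   # hypothetical frames per utterance
    r = 2                  # outputs per decoder step

    # zeros for every real frame except the last (mirrors the collate_fn)
    stop_targets = [np.zeros(l - 1) for l in mel_lengths]

    # pad with 1s up to a common length that divides r, as prepare_stop_target does
    max_len = max(len(t) for t in stop_targets) + 1      # +1 for the zero-frame
    pad_len = max_len + (r - max_len % r) if max_len % r else max_len
    batch = np.stack([np.pad(t, (0, pad_len - len(t)),
                             mode='constant', constant_values=1.)
                      for t in stop_targets])
    print(batch)
    # [[0. 0. 0. 0. 1. 1.]
    #  [0. 0. 1. 1. 1. 1.]]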
Binary file not shown.
layers/attention.py
@@ -48,7 +48,7 @@ class AttentionRNN(nn.Module):
     def __init__(self, out_dim, annot_dim, memory_dim,
                  score_mask_value=-float("inf")):
         super(AttentionRNN, self).__init__()
-        self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, out_dim)
+        self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
         self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
         self.score_mask_value = score_mask_value
 
@@ -57,11 +57,19 @@ class AttentionRNN(nn.Module):
 
         if annotations_lengths is not None and mask is None:
             mask = get_mask_from_lengths(annotations, annotations_lengths)
 
+        # Concat input query and previous context context
+        rnn_input = torch.cat((memory, context), -1)
+        #rnn_input = rnn_input.unsqueeze(1)
+
+        # Feed it to RNN
+        # s_i = f(y_{i-1}, c_{i}, s_{i-1})
+        rnn_output = self.rnn_cell(rnn_input, rnn_state)
+
         # Alignment
         # (batch, max_time)
         # e_{ij} = a(s_{i-1}, h_j)
-        alignment = self.alignment_model(annotations, rnn_state)
+        alignment = self.alignment_model(annotations, rnn_output)
 
         # TODO: needs recheck.
         if mask is not None:
@@ -75,16 +83,6 @@ class AttentionRNN(nn.Module):
         # (batch, 1, dim)
         # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
         context = torch.bmm(alignment.unsqueeze(1), annotations)
-        context = context.squeeze(1)
-
-        # Concat input query and previous context context
-        rnn_input = torch.cat((memory, context), -1)
-        #rnn_input = rnn_input.unsqueeze(1)
-
-        # Feed it to RNN
-        # s_i = f(y_{i-1}, c_{i}, s_{i-1})
-        rnn_output = self.rnn_cell(rnn_input, rnn_state)
-
         context = context.squeeze(1)
         return rnn_output, context, alignment
 
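
The substance of this change is the order of operations inside AttentionRNN.forward: the GRU cell now consumes the previous context before the alignment is computed, so the alignment is conditioned on the fresh state s_i rather than the stale s_{i-1} (the cell's input size changes accordingly). A self-contained sketch with toy dimensions; a plain dot-product score stands in for the repo's BahdanauAttention:

    import torch
    from torch import nn

    B, T, annot_dim, memory_dim, out_dim = 2, 8, 256, 80, 256
    rnn_cell = nn.GRUCell(annot_dim + memory_dim, out_dim)  # annot_dim == out_dim here

    memory = torch.rand(B, memory_dim)          # prenet output y_{i-1}
    context = torch.rand(B, annot_dim)          # previous context c_{i-1}
    rnn_state = torch.zeros(B, out_dim)         # previous state s_{i-1}
    annotations = torch.rand(B, T, annot_dim)   # encoder outputs h_j

    # 1) update the state first: s_i = f(y_{i-1}, c_{i-1}, s_{i-1})
    rnn_output = rnn_cell(torch.cat((memory, context), -1), rnn_state)
    # 2) score annotations against the *new* state: e_ij = a(s_i, h_j)
    scores = torch.bmm(annotations, rnn_output.unsqueeze(2)).squeeze(2)
    alignment = torch.softmax(scores, dim=1)
    # 3) c_i = sum_j alpha_ij h_j
    context = torch.bmm(alignment.unsqueeze(1), annotations).squeeze(1)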
@@ -0,0 +1,26 @@
+# coding: utf-8
+import torch
+from torch.autograd import Variable
+from torch import nn
+
+
+# class StopProjection(nn.Module):
+#     r""" Simple projection layer to predict the "stop token"
+
+#     Args:
+#         in_features (int): size of the input vector
+#         out_features (int or list): size of each output vector. aka number
+#             of predicted frames.
+#     """
+
+#     def __init__(self, in_features, out_features):
+#         super(StopProjection, self).__init__()
+#         self.linear = nn.Linear(in_features, out_features)
+#         self.dropout = nn.Dropout(0.5)
+#         self.sigmoid = nn.Sigmoid()
+
+#     def forward(self, inputs):
+#         out = self.dropout(inputs)
+#         out = self.linear(out)
+#         out = self.sigmoid(out)
+#         return out
layers/losses.py
@@ -0,0 +1,57 @@
+import torch
+from torch.nn import functional
+from torch.autograd import Variable
+from torch import nn
+
+
+# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
+def _sequence_mask(sequence_length, max_len=None):
+    if max_len is None:
+        max_len = sequence_length.data.max()
+    batch_size = sequence_length.size(0)
+    seq_range = torch.arange(0, max_len).long()
+    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
+    seq_range_expand = Variable(seq_range_expand)
+    if sequence_length.is_cuda:
+        seq_range_expand = seq_range_expand.cuda()
+    seq_length_expand = (sequence_length.unsqueeze(1)
+                         .expand_as(seq_range_expand))
+    return seq_range_expand < seq_length_expand
+
+
+class L1LossMasked(nn.Module):
+
+    def __init__(self):
+        super(L1LossMasked, self).__init__()
+
+    def forward(self, input, target, length):
+        """
+        Args:
+            logits: A Variable containing a FloatTensor of size
+                (batch, max_len, num_classes) which contains the
+                unnormalized probability for each class.
+            target: A Variable containing a LongTensor of size
+                (batch, max_len) which contains the index of the true
+                class for each corresponding step.
+            length: A Variable containing a LongTensor of size (batch,)
+                which contains the length of each data in a batch.
+        Returns:
+            loss: An average loss value masked by the length.
+        """
+        input = input.contiguous()
+        target = target.contiguous()
+
+        # logits_flat: (batch * max_len, dim)
+        input = input.view(-1, input.size(-1))
+        # target_flat: (batch * max_len, dim)
+        target_flat = target.view(-1, 1)
+        # losses_flat: (batch * max_len, dim)
+        losses_flat = functional.l1_loss(input, target, size_average=False,
+                                         reduce=False)
+        # losses: (batch, max_len, dim)
+        losses = losses_flat.view(*target.size())
+        # mask: (batch, max_len, 1)
+        mask = _sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2)
+        losses = losses * mask.float()
+        loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
+        return loss
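
L1LossMasked averages the per-element L1 error over real frames only, so padded timesteps contribute nothing to the loss or its gradient. A usage sketch under the pre-0.4 PyTorch API this commit targets (Variable, plus the deprecated size_average/reduce flags used above):

    import torch
    from torch.autograd import Variable
    from layers.losses import L1LossMasked  # assuming the repo root is on sys.path

    criterion = L1LossMasked()
    pred = Variable(torch.rand(4, 8, 80))      # batch x max_len x mel_dim
    target = Variable(torch.rand(4, 8, 80))
    lengths = Variable(torch.LongTensor([8, 6, 5, 3]))  # real frames per item
    loss = criterion(pred, target, lengths)    # frames past each length are masked out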
layers/tacotron.py
@@ -48,6 +48,7 @@ class BatchNormConv1d(nn.Module):
         - input: batch x dims
         - output: batch x dims
     """
+
     def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                  activation=None):
         super(BatchNormConv1d, self).__init__()
@@ -213,8 +214,9 @@ class Decoder(nn.Module):
         r (int): number of outputs per time step.
         eps (float): threshold for detecting the end of a sentence.
     """
-    def __init__(self, in_features, memory_dim, r, eps=0.05):
+    def __init__(self, in_features, memory_dim, r, eps=0.05, mode='train'):
         super(Decoder, self).__init__()
+        self.mode = mode
         self.max_decoder_steps = 200
         self.memory_dim = memory_dim
         self.eps = eps
@@ -241,7 +243,8 @@ class Decoder(nn.Module):
         Args:
             inputs: Encoder outputs.
             memory (None): Decoder memory (autoregression. If None (at eval-time),
-                decoder outputs are used as decoder inputs.
+                decoder outputs are used as decoder inputs. If None, it uses the last
+                output as the input.
 
         Shapes:
             - inputs: batch x time x encoder_out_dim
@@ -250,14 +253,13 @@ class Decoder(nn.Module):
         B = inputs.size(0)
 
         # Run greedy decoding if memory is None
-        greedy = memory is None
+        greedy = not self.training
 
         if memory is not None:
 
             # Grouping multiple frames if necessary
             if memory.size(-1) == self.memory_dim:
                 memory = memory.view(B, memory.size(1) // self.r, -1)
-                assert memory.size(-1) == self.memory_dim * self.r,\
                     " !! Dimension mismatch {} vs {} * {}".format(memory.size(-1),
                                                                   self.memory_dim, self.r)
             T_decoder = memory.size(1)
@@ -286,15 +288,23 @@ class Decoder(nn.Module):
         memory_input = initial_memory
         while True:
             if t > 0:
-                memory_input = outputs[-1] if greedy else memory[t - 1]
+                if greedy:
+                    memory_input = outputs[-1]
+                else:
+                    # combine prev. model output and prev. real target
+                    # memory_input = torch.div(outputs[-1] + memory[t-1], 2.0)
+                    # add a random noise
+                    # noise = torch.autograd.Variable(
+                    #     memory_input.data.new(memory_input.size()).normal_(0.0, 0.5))
+                    # memory_input = memory_input + noise
+                    memory_input = memory[t-1]
 
             # Prenet
             processed_memory = self.prenet(memory_input)
 
             # Attention RNN
             attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn(
-                processed_memory, current_context_vec, attention_rnn_hidden,
-                inputs)
+                processed_memory, current_context_vec, attention_rnn_hidden, inputs)
 
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(
@@ -306,8 +316,9 @@ class Decoder(nn.Module):
                     decoder_input, decoder_rnn_hiddens[idx])
                 # Residual connectinon
                 decoder_input = decoder_rnn_hiddens[idx] + decoder_input
 
             output = decoder_input
 
+
             # predict mel vectors from decoder vectors
             output = self.proj_to_mel(output)
@@ -317,17 +328,17 @@ class Decoder(nn.Module):
 
             t += 1
 
-            if greedy:
+            if (not greedy and self.training) or (greedy and memory is not None):
+                if t >= T_decoder:
+                    break
+            else:
                 if t > 1 and is_end_of_frames(output, self.eps):
                     break
                 elif t > self.max_decoder_steps:
                     print(" !! Decoder stopped with 'max_decoder_steps'. \
                         Something is probably wrong.")
                     break
-            else:
-                if t >= T_decoder:
-                    break
 
         assert greedy or len(outputs) == T_decoder
 
         # Back to batch first
@@ -338,4 +349,4 @@ class Decoder(nn.Module):
 
 
 def is_end_of_frames(output, eps=0.2): #0.2
-    return (output.data <= eps).all()
+    return (output.data <= eps).all()
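
With greedy now derived from self.training instead of memory is None, the stop test also has to separate the two cases: with ground-truth memory the loop runs exactly T_decoder steps, while free-running inference stops on a near-silent frame or the step cap. A toy re-creation of that control flow with fabricated frames (module state flattened into locals):

    import torch

    training, memory = False, None      # eval-time, no ground-truth frames
    greedy = not training               # was: greedy = memory is None
    T_decoder, max_decoder_steps, eps = None, 200, 0.2

    def is_end_of_frames(output, eps=0.2):
        return (output <= eps).all()

    t, outputs = 0, []
    while True:
        frame = torch.full((1, 80), 1.0 if t < 3 else 0.0)  # fake mel frame
        outputs.append(frame)
        t += 1
        if (not greedy and training) or (greedy and memory is not None):
            if t >= T_decoder:          # teacher-forced: exactly T_decoder steps
                break
        else:
            if t > 1 and is_end_of_frames(frame, eps):      # silence => stop
                break
            elif t > max_decoder_steps:
                break
    print(len(outputs))  # 4: three voiced frames plus the first silent one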
Binary file not shown.
models/tacotron.py
@@ -8,9 +8,10 @@ from TTS.layers.tacotron import Prenet, Encoder, Decoder, CBHG
 
 
 class Tacotron(nn.Module):
     def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80,
-                 freq_dim=1025, r=5, padding_idx=None):
+                 r=5, padding_idx=None):
 
         super(Tacotron, self).__init__()
+        self.r = r
         self.mel_dim = mel_dim
         self.linear_dim = linear_dim
         self.embedding = nn.Embedding(len(symbols), embedding_dim,
@@ -23,9 +24,10 @@ class Tacotron(nn.Module):
         self.decoder = Decoder(256, mel_dim, r)
 
         self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
-        self.last_linear = nn.Linear(mel_dim * 2, freq_dim)
+        self.last_linear = nn.Linear(mel_dim * 2, linear_dim)
 
     def forward(self, characters, mel_specs=None):
+
         B = characters.size(0)
 
         inputs = self.embedding(characters)
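
Dropping freq_dim removes a redundant constructor argument: the final projection always maps the CBHG post-net output, whose width is mel_dim * 2 because its GRU is bidirectional, onto the linear spectrogram width. A quick shape check with the default dimensions:

    import torch
    from torch import nn

    mel_dim, linear_dim = 80, 1025
    postnet_out = torch.rand(4, 30, mel_dim * 2)    # batch x time x 160 (BiGRU concat)
    last_linear = nn.Linear(mel_dim * 2, linear_dim)
    print(last_linear(postnet_out).shape)           # torch.Size([4, 30, 1025])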
@@ -2,6 +2,7 @@ import unittest
 import torch as T
 
 from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder
+from layers.losses import L1LossMasked, _sequence_mask
 
 
 class PrenetTests(unittest.TestCase):
@@ -32,23 +33,22 @@ class CBHGTests(unittest.TestCase):
 class DecoderTests(unittest.TestCase):
 
     def test_in_out(self):
-        layer = Decoder(in_features=128, memory_dim=32, r=5)
-        dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
-        dummy_memory = T.autograd.Variable(T.rand(4, 120, 32))
+        layer = Decoder(in_features=256, memory_dim=80, r=2)
+        dummy_input = T.autograd.Variable(T.rand(4, 8, 256))
+        dummy_memory = T.autograd.Variable(T.rand(4, 2, 80))
 
-        print(layer)
         output, alignment = layer(dummy_input, dummy_memory)
-        print(output.shape)
 
         assert output.shape[0] == 4
-        assert output.shape[1] == 120 / 5
-        assert output.shape[2] == 32 * 5
+        assert output.shape[1] == 1, "size not {}".format(output.shape[1])
+        assert output.shape[2] == 80 * 2, "size not {}".format(output.shape[2])
 
 
 class EncoderTests(unittest.TestCase):
 
     def test_in_out(self):
         layer = Encoder(128)
         dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
 
         print(layer)
         output = layer(dummy_input)
@@ -56,4 +56,29 @@ class EncoderTests(unittest.TestCase):
         assert output.shape[0] == 4
         assert output.shape[1] == 8
         assert output.shape[2] == 256  # 128 * 2 BiRNN
+
+
+class L1LossMaskedTests(unittest.TestCase):
+
+    def test_in_out(self):
+        layer = L1LossMasked()
+        dummy_input = T.autograd.Variable(T.ones(4, 8, 128).float())
+        dummy_target = T.autograd.Variable(T.ones(4, 8, 128).float())
+        dummy_length = T.autograd.Variable((T.ones(4) * 8).long())
+        output = layer(dummy_input, dummy_target, dummy_length)
+        assert output.shape[0] == 1
+        assert len(output.shape) == 1
+        assert output.data[0] == 0.0
+
+        dummy_input = T.autograd.Variable(T.ones(4, 8, 128).float())
+        dummy_target = T.autograd.Variable(T.zeros(4, 8, 128).float())
+        dummy_length = T.autograd.Variable((T.ones(4) * 8).long())
+        output = layer(dummy_input, dummy_target, dummy_length)
+        assert output.data[0] == 1.0, "1.0 vs {}".format(output.data[0])
+
+        dummy_input = T.autograd.Variable(T.ones(4, 8, 128).float())
+        dummy_target = T.autograd.Variable(T.zeros(4, 8, 128).float())
+        dummy_length = T.autograd.Variable((T.arange(5,9)).long())
+        mask = ((_sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        output = layer(dummy_input + mask, dummy_target, dummy_length)
+        assert output.data[0] == 1.0, "1.0 vs {}".format(output.data[0])
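
The last test case is the telling one: it corrupts only the padded positions (the inverted mask pushes them to -100) and still expects a loss of exactly 1.0, proving the padding never leaks into the loss. The mask trick in isolation, with _sequence_mask re-sketched for current PyTorch:

    import torch

    def _sequence_mask(lengths, max_len=None):
        # same idea as layers.losses._sequence_mask
        max_len = max_len or int(lengths.max())
        return torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)

    lengths = torch.arange(5, 9)                       # [5, 6, 7, 8]
    mask = (_sequence_mask(lengths).float() - 1.0) * 100.0
    print(mask[0])   # zeros over the 5 real steps, -100 over the 3 padded ones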
@@ -32,7 +32,7 @@ class TestDataset(unittest.TestCase):
                             c.power
                             )
 
-        dataloader = DataLoader(dataset, batch_size=c.batch_size,
+        dataloader = DataLoader(dataset, batch_size=2,
                                 shuffle=True, collate_fn=dataset.collate_fn,
                                 drop_last=True, num_workers=c.num_loader_workers)
 
@@ -43,8 +43,10 @@ class TestDataset(unittest.TestCase):
             text_lengths = data[1]
             linear_input = data[2]
             mel_input = data[3]
-            item_idx = data[4]
+            mel_lengths = data[4]
+            stop_target = data[5]
+            item_idx = data[6]
 
             neg_values = text_input[text_input < 0]
             check_count = len(neg_values)
             assert check_count == 0, \
@@ -70,8 +72,9 @@ class TestDataset(unittest.TestCase):
                             c.power
                             )
 
+        # Test for batch size 1
         dataloader = DataLoader(dataset, batch_size=1,
-                                shuffle=True, collate_fn=dataset.collate_fn,
+                                shuffle=False, collate_fn=dataset.collate_fn,
                                 drop_last=True, num_workers=c.num_loader_workers)
 
         for i, data in enumerate(dataloader):
@@ -81,13 +84,63 @@ class TestDataset(unittest.TestCase):
             text_lengths = data[1]
             linear_input = data[2]
             mel_input = data[3]
-            item_idx = data[4]
+            mel_lengths = data[4]
+            stop_target = data[5]
+            item_idx = data[6]
 
             # check the last time step to be zero padded
             assert mel_input[0, -1].sum() == 0
             assert mel_input[0, -2].sum() != 0
             assert linear_input[0, -1].sum() == 0
             assert linear_input[0, -2].sum() != 0
+            assert stop_target[0, -1] == 1
+            assert stop_target[0, -2] == 0
+            assert stop_target.sum() == 1
+            assert len(mel_lengths.shape) == 1
+            assert mel_lengths[0] == mel_input[0].shape[0]
+
+        # Test for batch size 2
+        dataloader = DataLoader(dataset, batch_size=2,
+                                shuffle=False, collate_fn=dataset.collate_fn,
+                                drop_last=False, num_workers=c.num_loader_workers)
+
+        for i, data in enumerate(dataloader):
+            if i == self.max_loader_iter:
+                break
+            text_input = data[0]
+            text_lengths = data[1]
+            linear_input = data[2]
+            mel_input = data[3]
+            mel_lengths = data[4]
+            stop_target = data[5]
+            item_idx = data[6]
+
+            if mel_lengths[0] > mel_lengths[1]:
+                idx = 0
+            else:
+                idx = 1
+
+            # check the first item in the batch
+            assert mel_input[idx, -1].sum() == 0
+            assert mel_input[idx, -2].sum() != 0, mel_input
+            assert linear_input[idx, -1].sum() == 0
+            assert linear_input[idx, -2].sum() != 0
+            assert stop_target[idx, -1] == 1
+            assert stop_target[idx, -2] == 0
+            assert stop_target[idx].sum() == 1
+            assert len(mel_lengths.shape) == 1
+            assert mel_lengths[idx] == mel_input[idx].shape[0]
+
+            # check the second itme in the batch
+            assert mel_input[1-idx, -1].sum() == 0
+            assert linear_input[1-idx, -1].sum() == 0
+            assert stop_target[1-idx, -1] == 1
+            assert len(mel_lengths.shape) == 1
+
+            # check batch conditions
+            assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
+            assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
train.py | 46
@@ -26,6 +26,7 @@ from utils.model import get_param_size
 from utils.visual import plot_alignment, plot_spectrogram
 from datasets.LJSpeech import LJSpeechDataset
 from models.tacotron import Tacotron
+from layers.losses import L1LossMasked
 
 
 use_cuda = torch.cuda.is_available()
@@ -80,7 +81,8 @@ def train(model, criterion, data_loader, optimizer, epoch):
         text_lengths = data[1]
         linear_input = data[2]
         mel_input = data[3]
+        mel_lengths = data[4]
 
         current_step = num_iter + args.restore_step + epoch * len(data_loader) + 1
 
         # setup lr
@@ -93,21 +95,14 @@ def train(model, criterion, data_loader, optimizer, epoch):
         # convert inputs to variables
         text_input_var = Variable(text_input)
         mel_spec_var = Variable(mel_input)
+        mel_lengths_var = Variable(mel_lengths)
         linear_spec_var = Variable(linear_input, volatile=True)
 
-        # sort sequence by length for curriculum learning
-        # TODO: might be unnecessary
-        sorted_lengths, indices = torch.sort(
-            text_lengths.view(-1), dim=0, descending=True)
-        sorted_lengths = sorted_lengths.long().numpy()
-        text_input_var = text_input_var[indices]
-        mel_spec_var = mel_spec_var[indices]
-        linear_spec_var = linear_spec_var[indices]
-
         # dispatch data to GPU
         if use_cuda:
             text_input_var = text_input_var.cuda()
             mel_spec_var = mel_spec_var.cuda()
+            mel_lengths_var = mel_lengths_var.cuda()
             linear_spec_var = linear_spec_var.cuda()
 
         # forward pass
@@ -115,10 +110,11 @@ def train(model, criterion, data_loader, optimizer, epoch):
             model.forward(text_input_var, mel_spec_var)
 
         # loss computation
-        mel_loss = criterion(mel_output, mel_spec_var)
-        linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \
+        mel_loss = criterion(mel_output, mel_spec_var, mel_lengths_var)
+        linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
                 + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
-                                  linear_spec_var[: ,: ,:n_priority_freq])
+                                  linear_spec_var[: ,: ,:n_priority_freq],
+                                  mel_lengths_var)
         loss = mel_loss + linear_loss
 
         # backpass and check the grad norm
@@ -215,28 +211,31 @@ def evaluate(model, criterion, data_loader, current_step):
         text_lengths = data[1]
         linear_input = data[2]
         mel_input = data[3]
+        mel_lengths = data[4]
 
         # convert inputs to variables
         text_input_var = Variable(text_input)
         mel_spec_var = Variable(mel_input)
+        mel_lengths_var = Variable(mel_lengths)
         linear_spec_var = Variable(linear_input, volatile=True)
 
         # dispatch data to GPU
         if use_cuda:
             text_input_var = text_input_var.cuda()
             mel_spec_var = mel_spec_var.cuda()
+            mel_lengths_var = mel_lengths_var.cuda()
             linear_spec_var = linear_spec_var.cuda()
 
         # forward pass
-        mel_output, linear_output, alignments =\
-            model.forward(text_input_var, mel_spec_var)
+        mel_output, linear_output, alignments = model.forward(text_input_var, mel_spec_var)
 
         # loss computation
-        mel_loss = criterion(mel_output, mel_spec_var)
-        linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \
+        mel_loss = criterion(mel_output, mel_spec_var, mel_lengths_var)
+        linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
                 + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
-                                  linear_spec_var[: ,: ,:n_priority_freq])
-        loss = mel_loss + linear_loss
+                                  linear_spec_var[: ,: ,:n_priority_freq],
+                                  mel_lengths_var)
+        loss = mel_loss + linear_loss
 
         step_time = time.time() - start_time
         epoch_time += step_time
@@ -333,17 +332,16 @@ def main(args):
                              pin_memory=True)
 
     model = Tacotron(c.embedding_size,
-                     c.hidden_size,
-                     c.num_mels,
                      c.num_freq,
+                     c.num_mels,
                      c.r)
 
     optimizer = optim.Adam(model.parameters(), lr=c.lr)
 
     if use_cuda:
-        criterion = nn.L1Loss().cuda()
+        criterion = L1LossMasked().cuda()
    else:
-        criterion = nn.L1Loss()
+        criterion = L1LossMasked()
 
     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
utils/data.py
@@ -1,7 +1,7 @@
 import numpy as np
 
 
-def pad_data(x, length):
+def _pad_data(x, length):
     _pad = 0
     assert x.ndim == 1
     return np.pad(x, (0, length - x.shape[0]),
@@ -11,7 +11,33 @@ def _pad_data(x, length):
 
 def prepare_data(inputs):
     max_len = max((len(x) for x in inputs))
-    return np.stack([pad_data(x, max_len) for x in inputs])
+    return np.stack([_pad_data(x, max_len) for x in inputs])
+
+
+def _pad_tensor(x, length):
+    _pad = 0
+    assert x.ndim == 2
+    x = np.pad(x, [[0, 0], [0, length - x.shape[1]]], mode='constant', constant_values=_pad)
+    return x
+
+
+def prepare_tensor(inputs, out_steps):
+    max_len = max((x.shape[1] for x in inputs)) + 1  # zero-frame
+    remainder = max_len % out_steps
+    pad_len = max_len + (out_steps - remainder) if remainder > 0 else max_len
+    return np.stack([_pad_tensor(x, pad_len) for x in inputs])
+
+
+def _pad_stop_target(x, length):
+    _pad = 1.
+    assert x.ndim == 1
+    return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
+
+
+def prepare_stop_target(inputs, out_steps):
+    max_len = max((x.shape[0] for x in inputs)) + 1  # zero-frame
+    remainder = max_len % out_steps
+    pad_len = max_len + (out_steps - remainder) if remainder > 0 else max_len
+    return np.stack([_pad_stop_target(x, pad_len) for x in inputs])
 
 
 def pad_per_step(inputs, pad_len):
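
Both prepare_* helpers pad to the batch maximum plus one zero-frame, then round up to a multiple of out_steps so every padded length divides r. For example, with hypothetical (num_freq, T) inputs of lengths 7 and 4 and out_steps=5:

    import numpy as np

    inputs = [np.ones((80, 7)), np.ones((80, 4))]
    out_steps = 5

    max_len = max(x.shape[1] for x in inputs) + 1    # 8, incl. the zero-frame
    remainder = max_len % out_steps                  # 3
    pad_len = max_len + (out_steps - remainder) if remainder > 0 else max_len
    batch = np.stack([np.pad(x, [[0, 0], [0, pad_len - x.shape[1]]],
                             mode='constant') for x in inputs])
    print(batch.shape)   # (2, 80, 10) -- 10 is a multiple of out_steps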