mirror of https://github.com/coqui-ai/TTS.git

Stop token prediction - does not train yet

This commit is contained in:
parent b407779a83
commit 79123da4c6
@@ -12,16 +12,16 @@
     "text_cleaner": "english_cleaners",
 
     "epochs": 2000,
-    "lr": 0.001,
+    "lr": 0.0003,
     "warmup_steps": 4000,
     "batch_size": 32,
-    "eval_batch_size": 32,
+    "eval_batch_size":32,
     "r": 5,
 
     "griffin_lim_iters": 60,
     "power": 1.5,
 
-    "num_loader_workers": 12,
+    "num_loader_workers": 8,
 
     "checkpoint": false,
     "save_step": 69,
@@ -7,7 +7,8 @@ from torch.utils.data import Dataset
 from TTS.utils.text import text_to_sequence
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.data import prepare_data, pad_data, pad_per_step
+from TTS.utils.data import (prepare_data, pad_data, pad_per_step,
+                            prepare_tensor, prepare_stop_target)
 
 
 class LJSpeechDataset(Dataset):
@@ -93,15 +94,26 @@ class LJSpeechDataset(Dataset):
             text_lenghts = np.array([len(x) for x in text])
             max_text_len = np.max(text_lenghts)
 
+            linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
+            mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
+            mel_lengths = [m.shape[1] for m in mel]
+
+            # compute 'stop token' targets
+            stop_targets = [np.array([0.]*mel_len) for mel_len in mel_lengths]
+
             # PAD sequences with largest length of the batch
             text = prepare_data(text).astype(np.int32)
             wav = prepare_data(wav)
 
-            linear = np.array([self.ap.spectrogram(w).astype('float32') for w in wav])
-            mel = np.array([self.ap.melspectrogram(w).astype('float32') for w in wav])
+            # PAD features with largest length of the batch
+            linear = prepare_tensor(linear)
+            mel = prepare_tensor(mel)
             assert mel.shape[2] == linear.shape[2]
             timesteps = mel.shape[2]
 
+            # PAD stop targets
+            stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)
+
             # PAD with zeros that can be divided by outputs per step
             if (timesteps + 1) % self.outputs_per_step != 0:
                 pad_len = self.outputs_per_step - \
@@ -112,7 +124,7 @@ class LJSpeechDataset(Dataset):
                 linear = pad_per_step(linear, pad_len)
                 mel = pad_per_step(mel, pad_len)
 
-            # reshape jombo
+            # reshape mojo
             linear = linear.transpose(0, 2, 1)
             mel = mel.transpose(0, 2, 1)
 
@@ -121,7 +133,8 @@ class LJSpeechDataset(Dataset):
             text = torch.LongTensor(text)
             linear = torch.FloatTensor(linear)
             mel = torch.FloatTensor(mel)
-            return text, text_lenghts, linear, mel, item_idxs[0]
+            stop_targets = torch.FloatTensor(stop_targets)
+            return text, text_lenghts, linear, mel, stop_targets, item_idxs[0]
 
         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
                          found {}"
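For reference, a minimal standalone sketch of how these stop targets behave in collate, mirroring the padding helpers this commit adds to the data utils further below (batch sizes and lengths here are illustrative, not from the commit):

import numpy as np

# toy batch: two mel specs with 5 and 7 frames, outputs_per_step (r) = 2
mel_lengths = [5, 7]
stop_targets = [np.array([0.] * l) for l in mel_lengths]

def pad_stop_target(x, length):
    # pad with 1. so every frame past the real audio is marked "stop"
    return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=1.)

out_steps = 2
max_len = max(x.shape[0] for x in stop_targets)
padded_len = max_len + out_steps - (max_len % out_steps)  # divisible by r
batch = np.stack([pad_stop_target(x, padded_len) for x in stop_targets])
print(batch.shape)  # (2, 8): zeros over real frames, ones over the padded tail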
Binary file not shown.
@@ -5,6 +5,7 @@ from torch import nn
 
 from .attention import AttentionRNN
 from .attention import get_mask_from_lengths
+from .custom_layers import StopProjection
 
 
 class Prenet(nn.Module):
     r""" Prenet as explained at https://arxiv.org/abs/1703.10135.
@@ -214,8 +215,9 @@ class Decoder(nn.Module):
         r (int): number of outputs per time step.
         eps (float): threshold for detecting the end of a sentence.
     """
-    def __init__(self, in_features, memory_dim, r, eps=0.05):
+    def __init__(self, in_features, memory_dim, r, eps=0.05, mode='train'):
         super(Decoder, self).__init__()
+        self.mode = mode
         self.max_decoder_steps = 200
         self.memory_dim = memory_dim
         self.eps = eps
@@ -231,6 +233,8 @@ class Decoder(nn.Module):
             [nn.GRUCell(256, 256) for _ in range(2)])
         # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
+        # RNN_state | attention_context -> |Linear| -> stop_token
+        self.stop_token = StopProjection(256 + in_features, r)
 
     def forward(self, inputs, memory=None):
         """
@@ -252,10 +256,9 @@ class Decoder(nn.Module):
         B = inputs.size(0)
 
         # Run greedy decoding if memory is None
-        greedy = memory is None
+        greedy = ~self.training
 
         if memory is not None:
 
             # Grouping multiple frames if necessary
             if memory.size(-1) == self.memory_dim:
                 memory = memory.view(B, memory.size(1) // self.r, -1)
@@ -283,6 +286,7 @@ class Decoder(nn.Module):
 
         outputs = []
         alignments = []
+        stop_outputs = []
 
         t = 0
         memory_input = initial_memory
@@ -292,11 +296,12 @@ class Decoder(nn.Module):
                 memory_input = outputs[-1]
             else:
                 # combine prev. model output and prev. real target
-                memory_input = torch.div(outputs[-1] + memory[t-1], 2.0)
+                # memory_input = torch.div(outputs[-1] + memory[t-1], 2.0)
                 # add a random noise
-                noise = torch.autograd.Variable(
-                    memory_input.data.new(memory_input.size()).normal_(0.0, 0.5))
-                memory_input = memory_input + noise
+                # noise = torch.autograd.Variable(
+                #     memory_input.data.new(memory_input.size()).normal_(0.0, 0.5))
+                # memory_input = memory_input + noise
+                memory_input = memory[t-1]
 
             # Prenet
             processed_memory = self.prenet(memory_input)
@@ -316,35 +321,42 @@ class Decoder(nn.Module):
                     decoder_input, decoder_rnn_hiddens[idx])
                 # Residual connectinon
                 decoder_input = decoder_rnn_hiddens[idx] + decoder_input
 
             output = decoder_input
+            stop_token_input = decoder_input
+
+            # stop token prediction
+            stop_token_input = torch.cat((output, current_context_vec), -1)
+            stop_output = self.stop_token(stop_token_input)
 
             # predict mel vectors from decoder vectors
             output = self.proj_to_mel(output)
 
             outputs += [output]
             alignments += [alignment]
+            stop_outputs += [stop_output]
 
             t += 1
 
-            if greedy:
+            if (not greedy and self.training) or (greedy and memory is not None):
+                if t >= T_decoder:
+                    break
+            else:
                 if t > 1 and is_end_of_frames(output, self.eps):
                     break
                 elif t > self.max_decoder_steps:
                     print(" !! Decoder stopped with 'max_decoder_steps'. \
                         Something is probably wrong.")
                     break
-            else:
-                if t >= T_decoder:
-                    break
 
         assert greedy or len(outputs) == T_decoder
 
         # Back to batch first
         alignments = torch.stack(alignments).transpose(0, 1)
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
+        stop_outputs = torch.stack(stop_outputs).transpose(0, 1).contiguous()
 
-        return outputs, alignments
+        return outputs, alignments, stop_outputs
 
 
 def is_end_of_frames(output, eps=0.2): #0.2
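StopProjection is imported from .custom_layers, which this diff does not show. A plausible minimal version, purely as an assumption about its interface (decoder state concatenated with the attention context in, r stop probabilities out, suitable for nn.BCELoss), could be:

from torch import nn

class StopProjection(nn.Module):
    # hypothetical sketch, not the actual custom_layers implementation
    def __init__(self, in_features, r):
        super(StopProjection, self).__init__()
        self.linear = nn.Linear(in_features, r)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        # (B, in_features) -> (B, r) probabilities in [0, 1]
        return self.sigmoid(self.linear(inputs))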
Binary file not shown.
@@ -11,6 +11,7 @@ class Tacotron(nn.Module):
                  freq_dim=1025, r=5, padding_idx=None):
 
         super(Tacotron, self).__init__()
+        self.r = r
         self.mel_dim = mel_dim
         self.linear_dim = linear_dim
         self.embedding = nn.Embedding(len(symbols), embedding_dim,
@@ -26,6 +27,7 @@ class Tacotron(nn.Module):
         self.last_linear = nn.Linear(mel_dim * 2, freq_dim)
 
     def forward(self, characters, mel_specs=None):
+
         B = characters.size(0)
 
         inputs = self.embedding(characters)
@@ -33,7 +35,7 @@ class Tacotron(nn.Module):
         encoder_outputs = self.encoder(inputs)
 
         # (B, T', mel_dim*r)
-        mel_outputs, alignments = self.decoder(
+        mel_outputs, alignments, stop_outputs = self.decoder(
             encoder_outputs, mel_specs)
 
         # Post net processing below
@@ -41,8 +43,9 @@ class Tacotron(nn.Module):
         # Reshape
         # (B, T, mel_dim)
         mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
+        stop_outputs = stop_outputs.view(B, -1)
 
         linear_outputs = self.postnet(mel_outputs)
         linear_outputs = self.last_linear(linear_outputs)
 
-        return mel_outputs, linear_outputs, alignments
+        return mel_outputs, linear_outputs, alignments, stop_outputs
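With this change the model returns four tensors instead of three. A hypothetical smoke test (constructor arguments, shapes, and the module path are assumptions; Variable wrapping matches the PyTorch idiom used elsewhere in this commit):

import torch
from torch.autograd import Variable
from TTS.models.tacotron import Tacotron  # module path assumed

model = Tacotron(embedding_dim=256, linear_dim=1025, mel_dim=80, r=5)
characters = Variable(torch.LongTensor([[10, 11, 12, 13, 14, 15]]))
mel_specs = Variable(torch.rand(1, 30, 80))  # 30 frames, divisible by r=5
mel_out, linear_out, alignments, stop_out = model(characters, mel_specs)
print(stop_out.size())  # expected (1, 30): one stop probability per mel frame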
@@ -37,18 +37,23 @@ class DecoderTests(unittest.TestCase):
         dummy_memory = T.autograd.Variable(T.rand(4, 120, 32))
 
         print(layer)
-        output, alignment = layer(dummy_input, dummy_memory)
+        output, alignment, stop_output = layer(dummy_input, dummy_memory)
         print(output.shape)
+        print(" > Stop ", stop_output.shape)
 
         assert output.shape[0] == 4
         assert output.shape[1] == 120 / 5
         assert output.shape[2] == 32 * 5
+        assert stop_output.shape[0] == 4
+        assert stop_output.shape[1] == 120 / 5
+        assert stop_output.shape[2] == 5
 
 
 class EncoderTests(unittest.TestCase):
 
     def test_in_out(self):
         layer = Encoder(128)
         dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
 
         print(layer)
         output = layer(dummy_input)

@@ -32,7 +32,7 @@ class TestDataset(unittest.TestCase):
                           c.power
                          )
 
-        dataloader = DataLoader(dataset, batch_size=c.batch_size,
+        dataloader = DataLoader(dataset, batch_size=2,
                                 shuffle=True, collate_fn=dataset.collate_fn,
                                 drop_last=True, num_workers=c.num_loader_workers)
 
@@ -43,7 +43,8 @@ class TestDataset(unittest.TestCase):
             text_lengths = data[1]
             linear_input = data[2]
             mel_input = data[3]
-            item_idx = data[4]
+            stop_targets = data[4]
+            item_idx = data[5]
 
             neg_values = text_input[text_input < 0]
             check_count = len(neg_values)
@@ -81,13 +82,16 @@ class TestDataset(unittest.TestCase):
             text_lengths = data[1]
             linear_input = data[2]
             mel_input = data[3]
-            item_idx = data[4]
+            stop_target = data[4]
+            item_idx = data[5]
 
             # check the last time step to be zero padded
             assert mel_input[0, -1].sum() == 0
             assert mel_input[0, -2].sum() != 0
             assert linear_input[0, -1].sum() == 0
             assert linear_input[0, -2].sum() != 0
+            assert stop_target[0, -1] == 1
+            assert stop_target.sum() == 1
train.py
@@ -63,11 +63,12 @@ def signal_handler(signal, frame):
     sys.exit(1)
 
 
-def train(model, criterion, data_loader, optimizer, epoch):
+def train(model, criterion, critetion_stop, data_loader, optimizer, epoch):
     model = model.train()
     epoch_time = 0
     avg_linear_loss = 0
     avg_mel_loss = 0
+    avg_stop_loss = 0
 
     print(" | > Epoch {}/{}".format(epoch, c.epochs))
     progbar = Progbar(len(data_loader.dataset) / c.batch_size)
@@ -80,6 +81,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
         text_lengths = data[1]
         linear_input = data[2]
         mel_input = data[3]
+        stop_targets = data[4]
 
         current_step = num_iter + args.restore_step + epoch * len(data_loader) + 1
 
@@ -93,6 +95,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
         # convert inputs to variables
         text_input_var = Variable(text_input)
         mel_spec_var = Variable(mel_input)
+        stop_targets_var = Variable(stop_targets)
         linear_spec_var = Variable(linear_input, volatile=True)
 
         # sort sequence by length for curriculum learning
@@ -109,9 +112,10 @@ def train(model, criterion, data_loader, optimizer, epoch):
             text_input_var = text_input_var.cuda()
             mel_spec_var = mel_spec_var.cuda()
             linear_spec_var = linear_spec_var.cuda()
+            stop_targets_var = stop_targets_var.cuda()
 
         # forward pass
-        mel_output, linear_output, alignments =\
+        mel_output, linear_output, alignments, stop_output =\
             model.forward(text_input_var, mel_spec_var)
 
         # loss computation
@@ -119,7 +123,8 @@ def train(model, criterion, data_loader, optimizer, epoch):
         linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \
             + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                               linear_spec_var[: ,: ,:n_priority_freq])
-        loss = mel_loss + linear_loss
+        stop_loss = critetion_stop(stop_output, stop_targets_var)
+        loss = mel_loss + linear_loss + 0.25*stop_loss
 
         # backpass and check the grad norm
         loss.backward()
@@ -136,6 +141,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
         # update
         progbar.update(num_iter+1, values=[('total_loss', loss.data[0]),
                                            ('linear_loss', linear_loss.data[0]),
+                                           ('stop_loss', stop_loss.data[0]),
                                            ('mel_loss', mel_loss.data[0]),
                                            ('grad_norm', grad_norm)])
 
@@ -144,6 +150,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
         tb.add_scalar('TrainIterLoss/LinearLoss', linear_loss.data[0],
                       current_step)
         tb.add_scalar('TrainIterLoss/MelLoss', mel_loss.data[0], current_step)
+        tb.add_scalar('TrainIterLoss/StopLoss', stop_loss.data[0], current_step)
         tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
                       current_step)
         tb.add_scalar('Params/GradNorm', grad_norm, current_step)
@@ -184,19 +191,21 @@ def train(model, criterion, data_loader, optimizer, epoch):
 
     avg_linear_loss /= (num_iter + 1)
     avg_mel_loss /= (num_iter + 1)
-    avg_total_loss = avg_mel_loss + avg_linear_loss
+    avg_stop_loss /= (num_iter + 1)
+    avg_total_loss = avg_mel_loss + avg_linear_loss + 0.25*avg_stop_loss
 
     # Plot Training Epoch Stats
     tb.add_scalar('TrainEpochLoss/TotalLoss', loss.data[0], current_step)
     tb.add_scalar('TrainEpochLoss/LinearLoss', linear_loss.data[0], current_step)
     tb.add_scalar('TrainEpochLoss/MelLoss', mel_loss.data[0], current_step)
+    tb.add_scalar('TrainEpochLoss/StopLoss', stop_loss.data[0], current_step)
    tb.add_scalar('Time/EpochTime', epoch_time, epoch)
     epoch_time = 0
 
     return avg_linear_loss, current_step
 
 
-def evaluate(model, criterion, data_loader, current_step):
+def evaluate(model, criterion, criterion_stop, data_loader, current_step):
     model = model.eval()
     epoch_time = 0
 
@@ -206,6 +215,7 @@ def evaluate(model, criterion, data_loader, current_step):
 
     avg_linear_loss = 0
     avg_mel_loss = 0
+    avg_stop_loss = 0
 
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
@@ -215,38 +225,44 @@ def evaluate(model, criterion, data_loader, current_step):
         text_lengths = data[1]
         linear_input = data[2]
         mel_input = data[3]
+        stop_targets = data[4]
 
         # convert inputs to variables
         text_input_var = Variable(text_input)
         mel_spec_var = Variable(mel_input)
         linear_spec_var = Variable(linear_input, volatile=True)
+        stop_targets_var = Variable(stop_targets)
 
         # dispatch data to GPU
         if use_cuda:
             text_input_var = text_input_var.cuda()
             mel_spec_var = mel_spec_var.cuda()
             linear_spec_var = linear_spec_var.cuda()
+            stop_targets_var = stop_targets_var.cuda()
 
         # forward pass
-        mel_output, linear_output, alignments = model.forward(text_input_var)
+        mel_output, linear_output, alignments, stop_output = model.forward(text_input_var, mel_spec_var)
 
         # loss computation
         mel_loss = criterion(mel_output, mel_spec_var)
         linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \
             + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                               linear_spec_var[: ,: ,:n_priority_freq])
-        loss = mel_loss + linear_loss
+        stop_loss = criterion_stop(stop_output, stop_targets_var)
+        loss = mel_loss + linear_loss + 0.25*stop_loss
 
         step_time = time.time() - start_time
         epoch_time += step_time
 
         # update
         progbar.update(num_iter+1, values=[('total_loss', loss.data[0]),
+                                           ('stop_loss', stop_loss.data[0]),
                                            ('linear_loss', linear_loss.data[0]),
                                            ('mel_loss', mel_loss.data[0])])
 
         avg_linear_loss += linear_loss.data[0]
         avg_mel_loss += mel_loss.data[0]
+        avg_stop_loss += stop_loss.data[0]
 
         # Diagnostic visualizations
         idx = np.random.randint(mel_input.shape[0])
@@ -278,12 +294,14 @@ def evaluate(model, criterion, data_loader, current_step):
     # compute average losses
     avg_linear_loss /= (num_iter + 1)
     avg_mel_loss /= (num_iter + 1)
-    avg_total_loss = avg_mel_loss + avg_linear_loss
+    avg_stop_loss /= (num_iter + 1)
+    avg_total_loss = avg_mel_loss + avg_linear_loss + 0.25*avg_stop_loss
 
     # Plot Learning Stats
     tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss, current_step)
     tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss, current_step)
     tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step)
+    tb.add_scalar('ValEpochLoss/StopLoss', avg_stop_loss, current_step)
     return avg_linear_loss
 
 
@@ -336,13 +354,15 @@ def main(args):
                      c.num_mels,
                      c.num_freq,
                      c.r)
 
     optimizer = optim.Adam(model.parameters(), lr=c.lr)
 
     if use_cuda:
         criterion = nn.L1Loss().cuda()
+        criterion_stop = nn.BCELoss().cuda()
     else:
         criterion = nn.L1Loss()
+        criterion_stop = nn.BCELoss()
 
     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
@@ -370,8 +390,8 @@ def main(args):
     best_loss = float('inf')
 
     for epoch in range(0, c.epochs):
-        train_loss, current_step = train(model, criterion, train_loader, optimizer, epoch)
-        val_loss = evaluate(model, criterion, val_loader, current_step)
+        train_loss, current_step = train(model, criterion, criterion_stop, train_loader, optimizer, epoch)
+        val_loss = evaluate(model, criterion, criterion_stop, val_loader, current_step)
         best_loss = save_best_model(model, optimizer, val_loss,
                                     best_loss, OUT_PATH,
                                     current_step, epoch)
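Taken on its own, the new loss wiring can be sanity-checked with random tensors. A minimal sketch (shapes are illustrative; the 0.25 weight mirrors this commit; nn.BCELoss requires stop_output to already be in [0, 1], which the sigmoid in the stop projection provides):

import torch
from torch import nn

criterion = nn.L1Loss()
criterion_stop = nn.BCELoss()

B, T, mel_dim = 2, 20, 80
mel_output = torch.rand(B, T, mel_dim)
mel_spec = torch.rand(B, T, mel_dim)
stop_output = torch.rand(B, T)                   # stands in for sigmoid outputs
stop_targets = (torch.rand(B, T) > 0.8).float()  # 1 marks "stop" frames

mel_loss = criterion(mel_output, mel_spec)
stop_loss = criterion_stop(stop_output, stop_targets)
loss = mel_loss + 0.25 * stop_loss  # stop loss down-weighted as in train()
print(loss)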
@@ -14,6 +14,29 @@ def prepare_data(inputs):
     return np.stack([pad_data(x, max_len) for x in inputs])
 
 
+def pad_tensor(x, length):
+    _pad = 0
+    assert x.ndim == 2
+    return np.pad(x, [[0, 0], [0, length- x.shape[1]]], mode='constant', constant_values=_pad)
+
+
+def prepare_tensor(inputs):
+    max_len = max((x.shape[1] for x in inputs))
+    return np.stack([pad_tensor(x, max_len) for x in inputs])
+
+
+def pad_stop_target(x, length):
+    _pad = 1.
+    assert x.ndim == 1
+    return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
+
+
+def prepare_stop_target(inputs, out_steps):
+    max_len = max((x.shape[0] for x in inputs))
+    remainder = max_len % out_steps
+    return np.stack([pad_stop_target(x, max_len + out_steps - remainder) for x in inputs])
+
+
 def pad_per_step(inputs, pad_len):
     timesteps = inputs.shape[-1]
     return np.pad(inputs, [[0, 0], [0, 0],
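A quick usage check for the new feature-padding helpers, as a self-contained copy of the two functions added above (the example shapes are illustrative):

import numpy as np

def pad_tensor(x, length):
    assert x.ndim == 2
    return np.pad(x, [[0, 0], [0, length - x.shape[1]]], mode='constant', constant_values=0)

def prepare_tensor(inputs):
    max_len = max(x.shape[1] for x in inputs)
    return np.stack([pad_tensor(x, max_len) for x in inputs])

specs = [np.ones((80, 5)), np.ones((80, 9))]  # (num_mels, T) features of unequal length
batch = prepare_tensor(specs)
print(batch.shape)  # (2, 80, 9): the shorter spec is zero-padded along time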