mirror of https://github.com/coqui-ai/TTS.git
remove stop token prediction
parent b9fbdfa7ce
commit 802c1cc5b4
@@ -4,23 +4,23 @@ from torch.autograd import Variable
 from torch import nn


-class StopProjection(nn.Module):
-    r""" Simple projection layer to predict the "stop token"
+# class StopProjection(nn.Module):
+#     r""" Simple projection layer to predict the "stop token"

-    Args:
-        in_features (int): size of the input vector
-        out_features (int or list): size of each output vector. aka number
-            of predicted frames.
-    """
+#     Args:
+#         in_features (int): size of the input vector
+#         out_features (int or list): size of each output vector. aka number
+#             of predicted frames.
+#     """

-    def __init__(self, in_features, out_features):
-        super(StopProjection, self).__init__()
-        self.linear = nn.Linear(in_features, out_features)
-        self.dropout = nn.Dropout(0.5)
-        self.sigmoid = nn.Sigmoid()
+#     def __init__(self, in_features, out_features):
+#         super(StopProjection, self).__init__()
+#         self.linear = nn.Linear(in_features, out_features)
+#         self.dropout = nn.Dropout(0.5)
+#         self.sigmoid = nn.Sigmoid()

-    def forward(self, inputs):
-        out = self.dropout(inputs)
-        out = self.linear(out)
-        out = self.sigmoid(out)
-        return out
+#     def forward(self, inputs):
+#         out = self.dropout(inputs)
+#         out = self.linear(out)
+#         out = self.sigmoid(out)
+#         return out
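For context: the StopProjection head being commented out maps a decoder state to r per-frame stop probabilities through dropout, a linear layer, and a sigmoid. A minimal, self-contained sketch of that computation (the batch size, input width of 512, and r=5 below are illustrative assumptions, not values taken from this commit):

    import torch
    from torch import nn

    # Equivalent of the removed head: dropout -> linear -> sigmoid.
    stop_projection = nn.Sequential(
        nn.Dropout(0.5),       # same dropout rate as the removed layer
        nn.Linear(512, 5),     # in_features=512, out_features=r=5 (assumed)
        nn.Sigmoid(),          # squash to stop probabilities in (0, 1)
    )

    decoder_state = torch.rand(2, 512)           # hypothetical batch of 2 states
    stop_probs = stop_projection(decoder_state)  # -> shape (2, 5)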
@@ -5,7 +5,6 @@ from torch import nn

 from .attention import AttentionRNN
 from .attention import get_mask_from_lengths
-from .custom_layers import StopProjection

 class Prenet(nn.Module):
     r""" Prenet as explained at https://arxiv.org/abs/1703.10135.

@@ -233,8 +232,6 @@ class Decoder(nn.Module):
             [nn.GRUCell(256, 256) for _ in range(2)])
         # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
-        # RNN_state | attention_context -> |Linear| -> stop_token
-        self.stop_token = StopProjection(256 + in_features, r)

     def forward(self, inputs, memory=None):
         """

@@ -286,7 +283,6 @@ class Decoder(nn.Module):

         outputs = []
         alignments = []
-        stop_outputs = []

         t = 0
         memory_input = initial_memory

@@ -323,18 +319,13 @@ class Decoder(nn.Module):
                 decoder_input = decoder_rnn_hiddens[idx] + decoder_input

             output = decoder_input
-            stop_token_input = decoder_input

-            # stop token prediction
-            stop_token_input = torch.cat((output, current_context_vec), -1)
-            stop_output = self.stop_token(stop_token_input)

             # predict mel vectors from decoder vectors
             output = self.proj_to_mel(output)

             outputs += [output]
             alignments += [alignment]
-            stop_outputs += [stop_output]

             t += 1

@@ -354,9 +345,8 @@ class Decoder(nn.Module):
         # Back to batch first
         alignments = torch.stack(alignments).transpose(0, 1)
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
-        stop_outputs = torch.stack(stop_outputs).transpose(0, 1).contiguous()

-        return outputs, alignments, stop_outputs
+        return outputs, alignments


 def is_end_of_frames(output, eps=0.2): #0.2
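With the stop head gone, the decoder returns only outputs and alignments, so inference has to stop on a silence heuristic such as the is_end_of_frames helper visible above rather than on a learned stop probability. A rough sketch contrasting the two criteria (the 0.5 threshold is an assumption for illustration):

    import torch

    def is_end_of_frames(output, eps=0.2):
        # heuristic this commit falls back on: stop once every value in the
        # predicted frame is near silence
        return (output.data <= eps).all()

    def reached_stop_token(stop_probs, threshold=0.5):
        # learned criterion this commit removes: stop once any predicted
        # stop probability crosses a threshold (threshold value assumed)
        return (stop_probs > threshold).any()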
train.py (40 changed lines)
@@ -63,12 +63,11 @@ def signal_handler(signal, frame):
     sys.exit(1)


-def train(model, criterion, critetion_stop, data_loader, optimizer, epoch):
+def train(model, criterion, data_loader, optimizer, epoch):
     model = model.train()
     epoch_time = 0
     avg_linear_loss = 0
     avg_mel_loss = 0
-    avg_stop_loss = 0

     print(" | > Epoch {}/{}".format(epoch, c.epochs))
     progbar = Progbar(len(data_loader.dataset) / c.batch_size)

@@ -81,7 +80,6 @@ def train(model, criterion, critetion_stop, data_loader, optimizer, epoch):
         text_lengths = data[1]
         linear_input = data[2]
         mel_input = data[3]
-        stop_targets = data[4]

         current_step = num_iter + args.restore_step + epoch * len(data_loader) + 1

@@ -95,7 +93,6 @@ def train(model, criterion, critetion_stop, data_loader, optimizer, epoch):
         # convert inputs to variables
         text_input_var = Variable(text_input)
         mel_spec_var = Variable(mel_input)
-        stop_targets_var = Variable(stop_targets)
         linear_spec_var = Variable(linear_input, volatile=True)

         # sort sequence by length for curriculum learning

@@ -112,10 +109,9 @@ def train(model, criterion, critetion_stop, data_loader, optimizer, epoch):
             text_input_var = text_input_var.cuda()
             mel_spec_var = mel_spec_var.cuda()
             linear_spec_var = linear_spec_var.cuda()
-            stop_targets_var = stop_targets_var.cuda()

         # forward pass
-        mel_output, linear_output, alignments, stop_output =\
+        mel_output, linear_output, alignments =\
             model.forward(text_input_var, mel_spec_var)

         # loss computation

@@ -123,8 +119,7 @@ def train(model, criterion, critetion_stop, data_loader, optimizer, epoch):
         linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \
                 + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                                   linear_spec_var[: ,: ,:n_priority_freq])
-        stop_loss = critetion_stop(stop_output, stop_targets_var)
-        loss = mel_loss + linear_loss + 0.25*stop_loss
+        loss = mel_loss + linear_loss

         # backpass and check the grad norm
         loss.backward()
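The two removed lines above carried the whole stop objective: a binary cross-entropy between the predicted stop probabilities and the stop targets, folded into the total loss with a 0.25 weight. A minimal sketch of that term (the target layout below, zeros until the final frame, is an assumption about how data[4] was built):

    import torch
    from torch import nn

    criterion_stop = nn.BCELoss()

    # Hypothetical batch: 2 sequences, 10 decoder steps each.
    stop_output = torch.rand(2, 10)    # sigmoid outputs of the stop head
    stop_targets = torch.zeros(2, 10)  # 0 = keep decoding
    stop_targets[:, -1] = 1.0          # 1 = stop at the last frame (assumed)

    stop_loss = criterion_stop(stop_output, stop_targets)
    # term dropped from the objective:
    # loss = mel_loss + linear_loss + 0.25 * stop_loss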
@@ -141,7 +136,6 @@ def train(model, criterion, critetion_stop, data_loader, optimizer, epoch):
         # update
         progbar.update(num_iter+1, values=[('total_loss', loss.data[0]),
                                            ('linear_loss', linear_loss.data[0]),
-                                           ('stop_loss', stop_loss.data[0]),
                                            ('mel_loss', mel_loss.data[0]),
                                            ('grad_norm', grad_norm)])

@@ -150,7 +144,6 @@ def train(model, criterion, critetion_stop, data_loader, optimizer, epoch):
         tb.add_scalar('TrainIterLoss/LinearLoss', linear_loss.data[0],
                       current_step)
         tb.add_scalar('TrainIterLoss/MelLoss', mel_loss.data[0], current_step)
-        tb.add_scalar('TrainIterLoss/StopLoss', stop_loss.data[0], current_step)
         tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
                       current_step)
         tb.add_scalar('Params/GradNorm', grad_norm, current_step)
@@ -191,21 +184,19 @@ def train(model, criterion, critetion_stop, data_loader, optimizer, epoch):

     avg_linear_loss /= (num_iter + 1)
     avg_mel_loss /= (num_iter + 1)
-    avg_stop_loss /= (num_iter + 1)
-    avg_total_loss = avg_mel_loss + avg_linear_loss + 0.25*avg_stop_loss
+    avg_total_loss = avg_mel_loss + avg_linear_loss

     # Plot Training Epoch Stats
     tb.add_scalar('TrainEpochLoss/TotalLoss', loss.data[0], current_step)
     tb.add_scalar('TrainEpochLoss/LinearLoss', linear_loss.data[0], current_step)
     tb.add_scalar('TrainEpochLoss/MelLoss', mel_loss.data[0], current_step)
-    tb.add_scalar('TrainEpochLoss/StopLoss', stop_loss.data[0], current_step)
     tb.add_scalar('Time/EpochTime', epoch_time, epoch)
     epoch_time = 0

     return avg_linear_loss, current_step


-def evaluate(model, criterion, criterion_stop, data_loader, current_step):
+def evaluate(model, criterion, data_loader, current_step):
     model = model.eval()
     epoch_time = 0

@@ -215,7 +206,6 @@ def evaluate(model, criterion, criterion_stop, data_loader, current_step):

     avg_linear_loss = 0
     avg_mel_loss = 0
-    avg_stop_loss = 0

     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
@@ -225,44 +215,38 @@ def evaluate(model, criterion, criterion_stop, data_loader, current_step):
         text_lengths = data[1]
         linear_input = data[2]
         mel_input = data[3]
-        stop_targets = data[4]

         # convert inputs to variables
         text_input_var = Variable(text_input)
         mel_spec_var = Variable(mel_input)
         linear_spec_var = Variable(linear_input, volatile=True)
-        stop_targets_var = Variable(stop_targets)

         # dispatch data to GPU
         if use_cuda:
             text_input_var = text_input_var.cuda()
             mel_spec_var = mel_spec_var.cuda()
             linear_spec_var = linear_spec_var.cuda()
-            stop_targets_var = stop_targets_var.cuda()

         # forward pass
-        mel_output, linear_output, alignments, stop_output = model.forward(text_input_var, mel_spec_var)
+        mel_output, linear_output, alignments = model.forward(text_input_var, mel_spec_var)

         # loss computation
         mel_loss = criterion(mel_output, mel_spec_var)
         linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \
                 + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                                   linear_spec_var[: ,: ,:n_priority_freq])
-        stop_loss = criterion_stop(stop_output, stop_targets_var)
-        loss = mel_loss + linear_loss + 0.25*stop_loss
+        loss = mel_loss + linear_loss

         step_time = time.time() - start_time
         epoch_time += step_time

         # update
         progbar.update(num_iter+1, values=[('total_loss', loss.data[0]),
-                                           ('stop_loss', stop_loss.data[0]),
                                            ('linear_loss', linear_loss.data[0]),
                                            ('mel_loss', mel_loss.data[0])])

         avg_linear_loss += linear_loss.data[0]
         avg_mel_loss += mel_loss.data[0]
-        avg_stop_loss += stop_loss.data[0]

         # Diagnostic visualizations
         idx = np.random.randint(mel_input.shape[0])
@@ -294,14 +278,12 @@ def evaluate(model, criterion, criterion_stop, data_loader, current_step):
     # compute average losses
     avg_linear_loss /= (num_iter + 1)
     avg_mel_loss /= (num_iter + 1)
-    avg_stop_loss /= (num_iter + 1)
-    avg_total_loss = avg_mel_loss + avg_linear_loss + 0.25*avg_stop_loss
+    avg_total_loss = avg_mel_loss + avg_linear_loss

     # Plot Learning Stats
     tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss, current_step)
     tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss, current_step)
     tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step)
-    tb.add_scalar('ValEpochLoss/StopLoss', avg_stop_loss, current_step)
     return avg_linear_loss

@@ -359,10 +341,8 @@ def main(args):

     if use_cuda:
         criterion = nn.L1Loss().cuda()
-        criterion_stop = nn.BCELoss().cuda()
     else:
         criterion = nn.L1Loss()
-        criterion_stop = nn.BCELoss()

     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
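Side note on the removed criterion: nn.BCELoss here consumed probabilities because StopProjection applied the sigmoid itself. Had the head stayed, a numerically stabler setup would feed raw logits to the fused loss instead; this is a suggestion on my part, not something this codebase did:

    from torch import nn

    # setup removed by this commit: sigmoid inside the model, plain BCE outside
    criterion_stop = nn.BCELoss()

    # stabler alternative (assumes the Sigmoid is dropped from the head):
    # sigmoid and BCE fused into one call
    criterion_stop_fused = nn.BCEWithLogitsLoss()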
@@ -390,8 +370,8 @@ def main(args):
     best_loss = float('inf')

     for epoch in range(0, c.epochs):
-        train_loss, current_step = train(model, criterion, criterion_stop, train_loader, optimizer, epoch)
-        val_loss = evaluate(model, criterion, criterion_stop, val_loader, current_step)
+        train_loss, current_step = train(model, criterion, train_loader, optimizer, epoch)
+        val_loss = evaluate(model, criterion, val_loader, current_step)
         best_loss = save_best_model(model, optimizer, val_loss,
                                     best_loss, OUT_PATH,
                                     current_step, epoch)