bug fix train.py add stop token

This commit is contained in:
Eren Golge 2018-05-11 04:24:57 -07:00
parent 8ed9f57a6d
commit 5a5b9263e2
1 changed files with 14 additions and 8 deletions

View File

@ -77,6 +77,11 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
linear_input = data[2]
mel_input = data[3]
mel_lengths = data[4]
stop_targets = data[5]
# set stop targets view, we predict a single stop token per r frames prediction
stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
current_step = num_iter + args.restore_step + \
epoch * len(data_loader) + 1
@ -94,13 +99,14 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
mel_input = mel_input.cuda()
mel_lengths = mel_lengths.cuda()
linear_input = linear_input.cuda()
stop_targets = stop_targets.cuda()
# forward pass
mel_output, linear_output, alignments, stop_tokens =\
model.forward(text_input, mel_input)
# loss computation
stop_loss = criterion_st(stop_tokens, stop_target)
stop_loss = criterion_st(stop_tokens, stop_targets)
mel_loss = criterion(mel_output, mel_input, mel_lengths)
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
@ -208,11 +214,11 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step):
linear_input = data[2]
mel_input = data[3]
mel_lengths = data[4]
stop_target = data[5]
stop_targets = data[5]
# set stop targets view, we predict a single stop token per r frames prediction
stop_target = stop_target.view(text_input.shape[0], stop_target.size(1) // c.r, -1)
stop_target = (stop_target.sum(2) > 0.0).float()
stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
# dispatch data to GPU
if use_cuda:
@ -220,14 +226,14 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step):
mel_input = mel_input.cuda()
mel_lengths = mel_lengths.cuda()
linear_input = linear_input.cuda()
stop_target = stop_target.cuda()
stop_targets = stop_targets.cuda()
# forward pass
mel_output, linear_output, alignments, stop_tokens =\
model.forward(text_input, mel_spec)
model.forward(text_input, mel_input)
# loss computation
stop_loss = criterion_st(stop_tokens, stop_target)
stop_loss = criterion_st(stop_tokens, stop_targets)
mel_loss = criterion(mel_output, mel_input, mel_lengths)
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],