From 5a5b9263e20176538cec749b0a08ba1e28161a31 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Fri, 11 May 2018 04:24:57 -0700 Subject: [PATCH] bug fix train.py add stop token --- train.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/train.py b/train.py index 11a42174..7fa485b6 100644 --- a/train.py +++ b/train.py @@ -77,6 +77,11 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch): linear_input = data[2] mel_input = data[3] mel_lengths = data[4] + stop_targets = data[5] + + # set stop targets view, we predict a single stop token per r frames prediction + stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float() current_step = num_iter + args.restore_step + \ epoch * len(data_loader) + 1 @@ -94,13 +99,14 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch): mel_input = mel_input.cuda() mel_lengths = mel_lengths.cuda() linear_input = linear_input.cuda() - + stop_targets = stop_targets.cuda() + # forward pass mel_output, linear_output, alignments, stop_tokens =\ model.forward(text_input, mel_input) # loss computation - stop_loss = criterion_st(stop_tokens, stop_target) + stop_loss = criterion_st(stop_tokens, stop_targets) mel_loss = criterion(mel_output, mel_input, mel_lengths) linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \ + 0.5 * criterion(linear_output[:, :, :n_priority_freq], @@ -208,11 +214,11 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step): linear_input = data[2] mel_input = data[3] mel_lengths = data[4] - stop_target = data[5] + stop_targets = data[5] # set stop targets view, we predict a single stop token per r frames prediction - stop_target = stop_target.view(text_input.shape[0], stop_target.size(1) // c.r, -1) - stop_target = (stop_target.sum(2) > 0.0).float() + stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float() # dispatch data to GPU if use_cuda: @@ -220,14 +226,14 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step): mel_input = mel_input.cuda() mel_lengths = mel_lengths.cuda() linear_input = linear_input.cuda() - stop_target = stop_target.cuda() + stop_targets = stop_targets.cuda() # forward pass mel_output, linear_output, alignments, stop_tokens =\ - model.forward(text_input, mel_spec) + model.forward(text_input, mel_input) # loss computation - stop_loss = criterion_st(stop_tokens, stop_target) + stop_loss = criterion_st(stop_tokens, stop_targets) mel_loss = criterion(mel_output, mel_input, mel_lengths) linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \ + 0.5 * criterion(linear_output[:, :, :n_priority_freq],