From 9924ae4a1d9a3f46e03ee7bff57d2483cdf05608 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Wed, 2 May 2018 06:07:20 -0700 Subject: [PATCH] bug fix --- train.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/train.py b/train.py index 79e12b6a..10a30adb 100644 --- a/train.py +++ b/train.py @@ -239,6 +239,9 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step): mel_lengths = data[4] stop_target = data[5] + stop_target = stop_target.view(c.batch_size, stop_target.size(1) // c.r, -1) + stop_target = (stop_target.sum(2) > 0.0).float().unsqueeze(2) + # dispatch data to GPU if use_cuda: text_input = text_input.cuda()