From fda7e7f6c96849d1113d6d75c5afbeba8df9e16d Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Mon, 30 Apr 2018 06:12:12 -0700 Subject: [PATCH] carry to cuda --- train.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/train.py b/train.py index de3bfe9c..3e6e6083 100644 --- a/train.py +++ b/train.py @@ -100,6 +100,8 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch): mel_spec = mel_spec.cuda() mel_lengths = mel_lengths.cuda() linear_spec = linear_spec.cuda() + stop_target = stop_target.cuda() + # create attention mask if c.mk > 0.0: @@ -241,6 +243,7 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step): mel_spec = mel_spec.cuda() mel_lengths = mel_lengths.cuda() linear_spec = linear_spec.cuda() + stop_target = stop_target.cuda() # forward pass mel_output, linear_output, alignments, stop_tokens =\