diff --git a/train.py b/train.py index de3bfe9c..3e6e6083 100644 --- a/train.py +++ b/train.py @@ -100,6 +100,8 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch): mel_spec = mel_spec.cuda() mel_lengths = mel_lengths.cuda() linear_spec = linear_spec.cuda() + stop_target = stop_target.cuda() + # create attention mask if c.mk > 0.0: @@ -241,6 +243,7 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step): mel_spec = mel_spec.cuda() mel_lengths = mel_lengths.cuda() linear_spec = linear_spec.cuda() + stop_target = stop_target.cuda() # forward pass mel_output, linear_output, alignments, stop_tokens =\