diff --git a/tests/tts_tests/test_tacotron2_model.py b/tests/tts_tests/test_tacotron2_model.py index 77c291f7..ed79a26d 100644 --- a/tests/tts_tests/test_tacotron2_model.py +++ b/tests/tts_tests/test_tacotron2_model.py @@ -301,7 +301,7 @@ class TacotronCapacitronTrainTest(unittest.TestCase): batch["stop_targets"] = (batch["stop_targets"].sum(2) > 0.0).unsqueeze(2).float().squeeze() model = Tacotron2(config).to(device) - criterion = model.get_criterion() + criterion = model.get_criterion().to(device) optimizer = model.get_optimizer() model.train()