From 4787a2a993c6c0f72b214732fb8c037ca3f3bf6f Mon Sep 17 00:00:00 2001 From: WeberJulian Date: Tue, 6 Dec 2022 18:07:48 +0100 Subject: [PATCH] Fix capacitron test when cuda is enabled --- tests/tts_tests/test_tacotron2_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tts_tests/test_tacotron2_model.py b/tests/tts_tests/test_tacotron2_model.py index 77c291f7..ed79a26d 100644 --- a/tests/tts_tests/test_tacotron2_model.py +++ b/tests/tts_tests/test_tacotron2_model.py @@ -301,7 +301,7 @@ class TacotronCapacitronTrainTest(unittest.TestCase): batch["stop_targets"] = (batch["stop_targets"].sum(2) > 0.0).unsqueeze(2).float().squeeze() model = Tacotron2(config).to(device) - criterion = model.get_criterion() + criterion = model.get_criterion().to(device) optimizer = model.get_optimizer() model.train()