Fix capacitron test when cuda is enabled

This commit is contained in:
WeberJulian 2022-12-06 18:07:48 +01:00
parent 9321b22203
commit 4787a2a993
1 changed files with 1 additions and 1 deletions

View File

@ -301,7 +301,7 @@ class TacotronCapacitronTrainTest(unittest.TestCase):
batch["stop_targets"] = (batch["stop_targets"].sum(2) > 0.0).unsqueeze(2).float().squeeze()
model = Tacotron2(config).to(device)
criterion = model.get_criterion()
criterion = model.get_criterion().to(device)
optimizer = model.get_optimizer()
model.train()