mirror of https://github.com/coqui-ai/TTS.git
Fix capacitron test when cuda is enabled
This commit is contained in:
parent
9321b22203
commit
4787a2a993
|
@ -301,7 +301,7 @@ class TacotronCapacitronTrainTest(unittest.TestCase):
|
||||||
batch["stop_targets"] = (batch["stop_targets"].sum(2) > 0.0).unsqueeze(2).float().squeeze()
|
batch["stop_targets"] = (batch["stop_targets"].sum(2) > 0.0).unsqueeze(2).float().squeeze()
|
||||||
|
|
||||||
model = Tacotron2(config).to(device)
|
model = Tacotron2(config).to(device)
|
||||||
criterion = model.get_criterion()
|
criterion = model.get_criterion().to(device)
|
||||||
optimizer = model.get_optimizer()
|
optimizer = model.get_optimizer()
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
|
|
Loading…
Reference in New Issue