mirror of https://github.com/coqui-ai/TTS.git
Merge pull request #2189 from coqui-ai/fix-capacitron-test
This commit is contained in:
commit
c753ad49cc
|
@ -301,7 +301,7 @@ class TacotronCapacitronTrainTest(unittest.TestCase):
|
||||||
batch["stop_targets"] = (batch["stop_targets"].sum(2) > 0.0).unsqueeze(2).float().squeeze()
|
batch["stop_targets"] = (batch["stop_targets"].sum(2) > 0.0).unsqueeze(2).float().squeeze()
|
||||||
|
|
||||||
model = Tacotron2(config).to(device)
|
model = Tacotron2(config).to(device)
|
||||||
criterion = model.get_criterion()
|
criterion = model.get_criterion().to(device)
|
||||||
optimizer = model.get_optimizer()
|
optimizer = model.get_optimizer()
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
|
|
Loading…
Reference in New Issue