mirror of https://github.com/coqui-ai/TTS.git
do not call .item() if returned loss is not torch type
This commit is contained in:
parent
8f72ad900a
commit
3c20afa1c9
|
@ -217,6 +217,9 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
||||||
scheduler_D.step()
|
scheduler_D.step()
|
||||||
|
|
||||||
for key, value in loss_D_dict.items():
|
for key, value in loss_D_dict.items():
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
loss_dict[key] = value
|
||||||
|
else:
|
||||||
loss_dict[key] = value.item()
|
loss_dict[key] = value.item()
|
||||||
|
|
||||||
step_time = time.time() - start_time
|
step_time = time.time() - start_time
|
||||||
|
@ -355,6 +358,9 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
|
||||||
|
|
||||||
loss_dict = dict()
|
loss_dict = dict()
|
||||||
for key, value in loss_G_dict.items():
|
for key, value in loss_G_dict.items():
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
loss_dict[key] = value
|
||||||
|
else:
|
||||||
loss_dict[key] = value.item()
|
loss_dict[key] = value.item()
|
||||||
|
|
||||||
##############################
|
##############################
|
||||||
|
@ -393,6 +399,9 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
|
||||||
loss_D_dict = criterion_D(scores_fake, scores_real)
|
loss_D_dict = criterion_D(scores_fake, scores_real)
|
||||||
|
|
||||||
for key, value in loss_D_dict.items():
|
for key, value in loss_D_dict.items():
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
loss_dict[key] = value
|
||||||
|
else:
|
||||||
loss_dict[key] = value.item()
|
loss_dict[key] = value.item()
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue