do not call .item() if returned loss is not torch type

This commit is contained in:
erogol 2020-06-12 20:55:06 +02:00
parent 8f72ad900a
commit 3c20afa1c9
1 changed files with 12 additions and 3 deletions

View File

@ -217,7 +217,10 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
scheduler_D.step()
for key, value in loss_D_dict.items():
loss_dict[key] = value.item()
if isinstance(value, (int, float)):
loss_dict[key] = value
else:
loss_dict[key] = value.item()
step_time = time.time() - start_time
epoch_time += step_time
@ -355,7 +358,10 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
loss_dict = dict()
for key, value in loss_G_dict.items():
loss_dict[key] = value.item()
if isinstance(value, (int, float)):
loss_dict[key] = value
else:
loss_dict[key] = value.item()
##############################
# DISCRIMINATOR
@ -393,7 +399,10 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
loss_D_dict = criterion_D(scores_fake, scores_real)
for key, value in loss_D_dict.items():
loss_dict[key] = value.item()
if isinstance(value, (int, float)):
loss_dict[key] = value
else:
loss_dict[key] = value.item()
step_time = time.time() - start_time