Don't OOR values in train console log

This commit is contained in:
Eren Gölge 2021-10-19 16:32:16 +00:00
parent c514351c0e
commit 3c7848e9b1
1 changed files with 11 additions and 3 deletions

View File

@ -47,11 +47,19 @@ class ConsoleLogger:
tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC
)
for key, value in loss_dict.items():
# print the avg value if given
if f"avg_{key}" in avg_loss_dict.keys():
log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f"avg_{key}"])
# print the avg value if given
if isinstance(value, float) and round(value, 5) == 0:
# do not round the number if it is zero when rounded
log_text += "{}{}: {} ({})\n".format(indent, key, value, avg_loss_dict[f"avg_{key}"])
else:
# print the rounded value
log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f"avg_{key}"])
else:
log_text += "{}{}: {:.5f} \n".format(indent, key, value)
if isinstance(value, float) and round(value, 5) == 0:
log_text += "{}{}: {} \n".format(indent, key, value)
else:
log_text += "{}{}: {:.5f} \n".format(indent, key, value)
print(log_text, flush=True)
# pylint: disable=unused-argument