more generci console_logger

This commit is contained in:
erogol 2020-07-28 13:48:33 +02:00
parent effe54f262
commit 84158c5e47
3 changed files with 23 additions and 11 deletions

View File

@ -218,10 +218,15 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
# print training progress
if global_step % c.print_step == 0:
log_dict = {
"avg_spec_length": [avg_spec_length, 1], # value, precision
"avg_text_length": [avg_text_length, 1],
"step_time": [step_time, 4],
"loader_time": [loader_time, 2],
"current_lr": current_lr,
}
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
avg_spec_length, avg_text_length,
step_time, loader_time, current_lr,
loss_dict, keep_avg.avg_values)
log_dict, loss_dict, keep_avg.avg_values)
if args.rank == 0:
# Plot Training Iter Stats

View File

@ -35,8 +35,7 @@ class ConsoleLogger():
def print_train_start(self):
print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}")
def print_train_step(self, batch_steps, step, global_step, avg_spec_length,
avg_text_length, step_time, loader_time, lr,
def print_train_step(self, batch_steps, step, global_step, log_dict,
loss_dict, avg_loss_dict):
indent = " | > "
print()
@ -48,8 +47,13 @@ class ConsoleLogger():
log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}'])
else:
log_text += "{}{}: {:.5f} \n".format(indent, key, value)
log_text += f"{indent}avg_spec_len: {avg_spec_length}\n{indent}avg_text_len: {avg_text_length}\n{indent}"\
f"step_time: {step_time:.2f}\n{indent}loader_time: {loader_time:.2f}\n{indent}lr: {lr:.5f}"
for idx, (key, value) in enumerate(log_dict.items()):
if isinstance(value, list):
log_text += f"{indent}{key}: {value[0]:.{value[1]}f}"
else:
log_text += f"{indent}{key}: {value}"
if idx < len(log_dict)-1:
log_text += "\n"
print(log_text, flush=True)
# pylint: disable=unused-argument
@ -82,14 +86,17 @@ class ConsoleLogger():
tcolors.BOLD, tcolors.ENDC)
for key, value in avg_loss_dict.items():
# print the avg value if given
color = tcolors.FAIL
color = ''
sign = '+'
diff = 0
if self.old_eval_loss_dict is not None:
if self.old_eval_loss_dict is not None and key in self.old_eval_loss_dict:
diff = value - self.old_eval_loss_dict[key]
if diff <= 0:
if diff < 0:
color = tcolors.OKGREEN
sign = ''
elif diff > 0:
color = tcolors.FAIL
sign = '+'
log_text += "{}{}:{} {:.5f} {}({}{:.5f})\n".format(indent, key, color, value, tcolors.ENDC, sign, diff)
self.old_eval_loss_dict = avg_loss_dict
print(log_text, flush=True)

View File

@ -35,7 +35,7 @@ class ConsoleLogger():
def print_train_start(self):
print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}")
def print_train_step(self, batch_steps, step, global_step,
def print_train_step(self, batch_steps, step, global_step, log_dict,
step_time, loader_time, lrG, lrD,
loss_dict, avg_loss_dict):
indent = " | > "