mirror of https://github.com/coqui-ai/TTS.git
more generci console_logger
This commit is contained in:
parent
effe54f262
commit
84158c5e47
|
@ -218,10 +218,15 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
|
|||
|
||||
# print training progress
|
||||
if global_step % c.print_step == 0:
|
||||
log_dict = {
|
||||
"avg_spec_length": [avg_spec_length, 1], # value, precision
|
||||
"avg_text_length": [avg_text_length, 1],
|
||||
"step_time": [step_time, 4],
|
||||
"loader_time": [loader_time, 2],
|
||||
"current_lr": current_lr,
|
||||
}
|
||||
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
|
||||
avg_spec_length, avg_text_length,
|
||||
step_time, loader_time, current_lr,
|
||||
loss_dict, keep_avg.avg_values)
|
||||
log_dict, loss_dict, keep_avg.avg_values)
|
||||
|
||||
if args.rank == 0:
|
||||
# Plot Training Iter Stats
|
||||
|
|
|
@ -35,8 +35,7 @@ class ConsoleLogger():
|
|||
def print_train_start(self):
|
||||
print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}")
|
||||
|
||||
def print_train_step(self, batch_steps, step, global_step, avg_spec_length,
|
||||
avg_text_length, step_time, loader_time, lr,
|
||||
def print_train_step(self, batch_steps, step, global_step, log_dict,
|
||||
loss_dict, avg_loss_dict):
|
||||
indent = " | > "
|
||||
print()
|
||||
|
@ -48,8 +47,13 @@ class ConsoleLogger():
|
|||
log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}'])
|
||||
else:
|
||||
log_text += "{}{}: {:.5f} \n".format(indent, key, value)
|
||||
log_text += f"{indent}avg_spec_len: {avg_spec_length}\n{indent}avg_text_len: {avg_text_length}\n{indent}"\
|
||||
f"step_time: {step_time:.2f}\n{indent}loader_time: {loader_time:.2f}\n{indent}lr: {lr:.5f}"
|
||||
for idx, (key, value) in enumerate(log_dict.items()):
|
||||
if isinstance(value, list):
|
||||
log_text += f"{indent}{key}: {value[0]:.{value[1]}f}"
|
||||
else:
|
||||
log_text += f"{indent}{key}: {value}"
|
||||
if idx < len(log_dict)-1:
|
||||
log_text += "\n"
|
||||
print(log_text, flush=True)
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
|
@ -82,14 +86,17 @@ class ConsoleLogger():
|
|||
tcolors.BOLD, tcolors.ENDC)
|
||||
for key, value in avg_loss_dict.items():
|
||||
# print the avg value if given
|
||||
color = tcolors.FAIL
|
||||
color = ''
|
||||
sign = '+'
|
||||
diff = 0
|
||||
if self.old_eval_loss_dict is not None:
|
||||
if self.old_eval_loss_dict is not None and key in self.old_eval_loss_dict:
|
||||
diff = value - self.old_eval_loss_dict[key]
|
||||
if diff <= 0:
|
||||
if diff < 0:
|
||||
color = tcolors.OKGREEN
|
||||
sign = ''
|
||||
elif diff > 0:
|
||||
color = tcolors.FAIL
|
||||
sign = '+'
|
||||
log_text += "{}{}:{} {:.5f} {}({}{:.5f})\n".format(indent, key, color, value, tcolors.ENDC, sign, diff)
|
||||
self.old_eval_loss_dict = avg_loss_dict
|
||||
print(log_text, flush=True)
|
||||
|
|
|
@ -35,7 +35,7 @@ class ConsoleLogger():
|
|||
def print_train_start(self):
|
||||
print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}")
|
||||
|
||||
def print_train_step(self, batch_steps, step, global_step,
|
||||
def print_train_step(self, batch_steps, step, global_step, log_dict,
|
||||
step_time, loader_time, lrG, lrD,
|
||||
loss_dict, avg_loss_dict):
|
||||
indent = " | > "
|
||||
|
|
Loading…
Reference in New Issue