From 84158c5e47ed915422f19c2b1adf302d6ddbe69a Mon Sep 17 00:00:00 2001 From: erogol Date: Tue, 28 Jul 2020 13:48:33 +0200 Subject: [PATCH] more generci console_logger --- TTS/bin/train_tts.py | 11 ++++++++--- TTS/tts/utils/console_logger.py | 21 ++++++++++++++------- TTS/vocoder/utils/console_logger.py | 2 +- 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index ef14f627..a5bf31b1 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -218,10 +218,15 @@ def train(model, criterion, optimizer, optimizer_st, scheduler, # print training progress if global_step % c.print_step == 0: + log_dict = { + "avg_spec_length": [avg_spec_length, 1], # value, precision + "avg_text_length": [avg_text_length, 1], + "step_time": [step_time, 4], + "loader_time": [loader_time, 2], + "current_lr": current_lr, + } c_logger.print_train_step(batch_n_iter, num_iter, global_step, - avg_spec_length, avg_text_length, - step_time, loader_time, current_lr, - loss_dict, keep_avg.avg_values) + log_dict, loss_dict, keep_avg.avg_values) if args.rank == 0: # Plot Training Iter Stats diff --git a/TTS/tts/utils/console_logger.py b/TTS/tts/utils/console_logger.py index 85d5b376..3affd6af 100644 --- a/TTS/tts/utils/console_logger.py +++ b/TTS/tts/utils/console_logger.py @@ -35,8 +35,7 @@ class ConsoleLogger(): def print_train_start(self): print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}") - def print_train_step(self, batch_steps, step, global_step, avg_spec_length, - avg_text_length, step_time, loader_time, lr, + def print_train_step(self, batch_steps, step, global_step, log_dict, loss_dict, avg_loss_dict): indent = " | > " print() @@ -48,8 +47,13 @@ class ConsoleLogger(): log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}']) else: log_text += "{}{}: {:.5f} \n".format(indent, key, value) - log_text += f"{indent}avg_spec_len: {avg_spec_length}\n{indent}avg_text_len: {avg_text_length}\n{indent}"\ - f"step_time: {step_time:.2f}\n{indent}loader_time: {loader_time:.2f}\n{indent}lr: {lr:.5f}" + for idx, (key, value) in enumerate(log_dict.items()): + if isinstance(value, list): + log_text += f"{indent}{key}: {value[0]:.{value[1]}f}" + else: + log_text += f"{indent}{key}: {value}" + if idx < len(log_dict)-1: + log_text += "\n" print(log_text, flush=True) # pylint: disable=unused-argument @@ -82,14 +86,17 @@ class ConsoleLogger(): tcolors.BOLD, tcolors.ENDC) for key, value in avg_loss_dict.items(): # print the avg value if given - color = tcolors.FAIL + color = '' sign = '+' diff = 0 - if self.old_eval_loss_dict is not None: + if self.old_eval_loss_dict is not None and key in self.old_eval_loss_dict: diff = value - self.old_eval_loss_dict[key] - if diff <= 0: + if diff < 0: color = tcolors.OKGREEN sign = '' + elif diff > 0: + color = tcolors.FAIL + sign = '+' log_text += "{}{}:{} {:.5f} {}({}{:.5f})\n".format(indent, key, color, value, tcolors.ENDC, sign, diff) self.old_eval_loss_dict = avg_loss_dict print(log_text, flush=True) diff --git a/TTS/vocoder/utils/console_logger.py b/TTS/vocoder/utils/console_logger.py index 6af0b823..b8908391 100644 --- a/TTS/vocoder/utils/console_logger.py +++ b/TTS/vocoder/utils/console_logger.py @@ -35,7 +35,7 @@ class ConsoleLogger(): def print_train_start(self): print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}") - def print_train_step(self, batch_steps, step, global_step, + def print_train_step(self, batch_steps, step, global_step, log_dict, step_time, loader_time, lrG, lrD, loss_dict, avg_loss_dict): indent = " | > "