diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index a5bf31b1..f7a69eb4 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -11,11 +11,9 @@ import traceback import numpy as np import torch from torch.utils.data import DataLoader - from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.TTSDataset import MyDataset from TTS.tts.layers.losses import TacotronLoss -from TTS.tts.utils.console_logger import ConsoleLogger from TTS.tts.utils.distribute import (DistributedSampler, apply_gradient_allreduce, init_distributed, reduce_tensor) @@ -28,6 +26,7 @@ from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor +from TTS.utils.console_logger import ConsoleLogger from TTS.utils.generic_utils import (KeepAverage, count_parameters, create_experiment_folder, get_git_branch, remove_experiment_folder, set_init_dict) diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py index dc081a5e..23e6bd3b 100644 --- a/TTS/bin/train_vocoder.py +++ b/TTS/bin/train_vocoder.py @@ -4,13 +4,12 @@ import os import sys import time import traceback +from inspect import signature import torch from torch.utils.data import DataLoader - -from inspect import signature - from TTS.utils.audio import AudioProcessor +from TTS.utils.console_logger import ConsoleLogger from TTS.utils.generic_utils import (KeepAverage, count_parameters, create_experiment_folder, get_git_branch, remove_experiment_folder, set_init_dict) @@ -23,12 +22,10 @@ from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data # from distribute import (DistributedSampler, apply_gradient_allreduce, # init_distributed, reduce_tensor) from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss -from TTS.vocoder.utils.io import save_checkpoint, save_best_model -from TTS.vocoder.utils.console_logger import ConsoleLogger from TTS.vocoder.utils.generic_utils import (check_config, plot_results, setup_discriminator, setup_generator) - +from TTS.vocoder.utils.io import save_best_model, save_checkpoint use_cuda, num_gpus = setup_torch_training_env(True, True) @@ -238,10 +235,14 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, # print training stats if global_step % c.print_step == 0: + log_dict = { + 'step_time': [step_time, 2], + 'loader_time': [loader_time, 4], + "current_lr_G": current_lr_G, + "current_lr_D": current_lr_D + } c_logger.print_train_step(batch_n_iter, num_iter, global_step, - step_time, loader_time, current_lr_G, - current_lr_D, loss_dict, - keep_avg.avg_values) + log_dict, loss_dict, keep_avg.avg_values) # plot step stats if global_step % 10 == 0: diff --git a/TTS/tts/utils/console_logger.py b/TTS/utils/console_logger.py similarity index 100% rename from TTS/tts/utils/console_logger.py rename to TTS/utils/console_logger.py diff --git a/TTS/vocoder/utils/console_logger.py b/TTS/vocoder/utils/console_logger.py deleted file mode 100644 index b8908391..00000000 --- a/TTS/vocoder/utils/console_logger.py +++ /dev/null @@ -1,97 +0,0 @@ -import datetime -from TTS.utils.io import AttrDict - - -tcolors = AttrDict({ - 'OKBLUE': '\033[94m', - 'HEADER': '\033[95m', - 'OKGREEN': '\033[92m', - 'WARNING': '\033[93m', - 'FAIL': '\033[91m', - 'ENDC': '\033[0m', - 'BOLD': '\033[1m', - 'UNDERLINE': '\033[4m' -}) - - -class ConsoleLogger(): - # TODO: merge this with TTS ConsoleLogger - def __init__(self): - # use these to compare values between iterations - self.old_train_loss_dict = None - self.old_epoch_loss_dict = None - self.old_eval_loss_dict = None - - # pylint: disable=no-self-use - def get_time(self): - now = datetime.datetime.now() - return now.strftime("%Y-%m-%d %H:%M:%S") - - def print_epoch_start(self, epoch, max_epoch): - print("\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD, - epoch, max_epoch, tcolors.ENDC), - flush=True) - - def print_train_start(self): - print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}") - - def print_train_step(self, batch_steps, step, global_step, log_dict, - step_time, loader_time, lrG, lrD, - loss_dict, avg_loss_dict): - indent = " | > " - print() - log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format( - tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC) - for key, value in loss_dict.items(): - # print the avg value if given - if f'avg_{key}' in avg_loss_dict.keys(): - log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}']) - else: - log_text += "{}{}: {:.5f} \n".format(indent, key, value) - log_text += f"{indent}step_time: {step_time:.2f}\n{indent}loader_time: {loader_time:.2f}\n{indent}lrG: {lrG}\n{indent}lrD: {lrD}" - print(log_text, flush=True) - - # pylint: disable=unused-argument - def print_train_epoch_end(self, global_step, epoch, epoch_time, - print_dict): - indent = " | > " - log_text = f"\n{tcolors.BOLD} --> TRAIN PERFORMACE -- EPOCH TIME: {epoch_time:.2f} sec -- GLOBAL_STEP: {global_step}{tcolors.ENDC}\n" - for key, value in print_dict.items(): - log_text += "{}{}: {:.5f}\n".format(indent, key, value) - print(log_text, flush=True) - - def print_eval_start(self): - print(f"{tcolors.BOLD} > EVALUATION {tcolors.ENDC}\n") - - def print_eval_step(self, step, loss_dict, avg_loss_dict): - indent = " | > " - print() - log_text = f"{tcolors.BOLD} --> STEP: {step}{tcolors.ENDC}\n" - for key, value in loss_dict.items(): - # print the avg value if given - if f'avg_{key}' in avg_loss_dict.keys(): - log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}']) - else: - log_text += "{}{}: {:.5f} \n".format(indent, key, value) - print(log_text, flush=True) - - def print_epoch_end(self, epoch, avg_loss_dict): - indent = " | > " - log_text = " {}--> EVAL PERFORMANCE{}\n".format( - tcolors.BOLD, tcolors.ENDC) - for key, value in avg_loss_dict.items(): - # print the avg value if given - color = '' - sign = '+' - diff = 0 - if self.old_eval_loss_dict is not None and key in self.old_eval_loss_dict: - diff = value - self.old_eval_loss_dict[key] - if diff < 0: - color = tcolors.OKGREEN - sign = '' - elif diff > 0: - color = tcolors.FAIL - sign = '+' - log_text += "{}{}:{} {:.5f} {}({}{:.5f})\n".format(indent, key, color, value, tcolors.ENDC, sign, diff) - self.old_eval_loss_dict = avg_loss_dict - print(log_text, flush=True)