From f3e9d61330dba326a0b528ee9d736eccf29f41d4 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Tue, 20 Jul 2021 19:51:35 +0530 Subject: [PATCH] Refactor logging initialization --- TTS/trainer.py | 37 +++++------------------------------ TTS/utils/logging/__init__.py | 21 ++++++++++++++++++++ 2 files changed, 26 insertions(+), 32 deletions(-) diff --git a/TTS/trainer.py b/TTS/trainer.py index 6acbe051..49f6a58c 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -38,7 +38,7 @@ from TTS.utils.generic_utils import ( to_cuda, ) from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint -from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger +from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_logger from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.models import setup_model as setup_vocoder_model @@ -160,28 +160,14 @@ class Trainer: self.output_path = output_path self.args = args self.config = config - + self.config.output_log_path = output_path # init loggers self.c_logger = ConsoleLogger() if c_logger is None else c_logger self.dashboard_logger = dashboard_logger if self.dashboard_logger is None: - if config.dashboard_logger == "tensorboard": - self.dashboard_logger = TensorboardLogger(output_path, model_name=config.model) + self.dashboard_logger = init_logger(config) - elif config.dashboard_logger == "wandb": - project_name = config.model - if config.project_name: - project_name = config.project_name - - self.dashboard_logger = WandbLogger( - project=project_name, - name=config.run_name, - config=config, - entity=config.wandb_entity, - ) - - self.dashboard_logger.add_text("model-config", f"
{config.to_json()}
", 0) if not self.config.log_model_step: self.config.log_model_step = self.config.save_step @@ -1129,6 +1115,7 @@ def process_args(args, config=None): if not experiment_path: experiment_path = get_experiment_folder_path(config.output_path, config.run_name) audio_path = os.path.join(experiment_path, "test_audios") + config.output_log_path = experiment_path # setup rank 0 process in distributed training dashboard_logger = None if args.rank == 0: @@ -1146,21 +1133,7 @@ def process_args(args, config=None): os.chmod(audio_path, 0o775) os.chmod(experiment_path, 0o775) - if config.dashboard_logger == "tensorboard": - dashboard_logger = TensorboardLogger(config.output_path, model_name=config.model) - - elif config.dashboard_logger == "wandb": - project_name = config.model - if config.project_name: - project_name = config.project_name - - dashboard_logger = WandbLogger( - project=project_name, - name=config.run_name, - config=config, - entity=config.wandb_entity, - ) - dashboard_logger.add_text("model-config", f"
{config.to_json()}
", 0) + dashboard_logger = init_logger(config) c_logger = ConsoleLogger() return config, experiment_path, audio_path, c_logger, dashboard_logger diff --git a/TTS/utils/logging/__init__.py b/TTS/utils/logging/__init__.py index a39bb912..4b92221f 100644 --- a/TTS/utils/logging/__init__.py +++ b/TTS/utils/logging/__init__.py @@ -1,3 +1,24 @@ from TTS.utils.logging.console_logger import ConsoleLogger from TTS.utils.logging.tensorboard_logger import TensorboardLogger from TTS.utils.logging.wandb_logger import WandbLogger + + +def init_logger(config): + if config.dashboard_logger == "tensorboard": + dashboard_logger = TensorboardLogger(config.output_log_path, model_name=config.model) + + elif config.dashboard_logger == "wandb": + project_name = config.model + if config.project_name: + project_name = config.project_name + + dashboard_logger = WandbLogger( + project=project_name, + name=config.run_name, + config=config, + entity=config.wandb_entity, + ) + + dashboard_logger.add_text("model-config", f"
{config.to_json()}
", 0) + + return dashboard_logger