Refactor logging initialization

This commit is contained in:
Ayush Chaurasia 2021-07-20 19:51:35 +05:30 committed by Eren Gölge
parent 79b74a989d
commit f3e9d61330
2 changed files with 26 additions and 32 deletions

View File

@ -38,7 +38,7 @@ from TTS.utils.generic_utils import (
to_cuda, to_cuda,
) )
from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_logger
from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.models import setup_model as setup_vocoder_model from TTS.vocoder.models import setup_model as setup_vocoder_model
@ -160,28 +160,14 @@ class Trainer:
self.output_path = output_path self.output_path = output_path
self.args = args self.args = args
self.config = config self.config = config
self.config.output_log_path = output_path
# init loggers # init loggers
self.c_logger = ConsoleLogger() if c_logger is None else c_logger self.c_logger = ConsoleLogger() if c_logger is None else c_logger
self.dashboard_logger = dashboard_logger self.dashboard_logger = dashboard_logger
if self.dashboard_logger is None: if self.dashboard_logger is None:
if config.dashboard_logger == "tensorboard": self.dashboard_logger = init_logger(config)
self.dashboard_logger = TensorboardLogger(output_path, model_name=config.model)
elif config.dashboard_logger == "wandb":
project_name = config.model
if config.project_name:
project_name = config.project_name
self.dashboard_logger = WandbLogger(
project=project_name,
name=config.run_name,
config=config,
entity=config.wandb_entity,
)
self.dashboard_logger.add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
if not self.config.log_model_step: if not self.config.log_model_step:
self.config.log_model_step = self.config.save_step self.config.log_model_step = self.config.save_step
@ -1129,6 +1115,7 @@ def process_args(args, config=None):
if not experiment_path: if not experiment_path:
experiment_path = get_experiment_folder_path(config.output_path, config.run_name) experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
audio_path = os.path.join(experiment_path, "test_audios") audio_path = os.path.join(experiment_path, "test_audios")
config.output_log_path = experiment_path
# setup rank 0 process in distributed training # setup rank 0 process in distributed training
dashboard_logger = None dashboard_logger = None
if args.rank == 0: if args.rank == 0:
@ -1146,21 +1133,7 @@ def process_args(args, config=None):
os.chmod(audio_path, 0o775) os.chmod(audio_path, 0o775)
os.chmod(experiment_path, 0o775) os.chmod(experiment_path, 0o775)
if config.dashboard_logger == "tensorboard": dashboard_logger = init_logger(config)
dashboard_logger = TensorboardLogger(config.output_path, model_name=config.model)
elif config.dashboard_logger == "wandb":
project_name = config.model
if config.project_name:
project_name = config.project_name
dashboard_logger = WandbLogger(
project=project_name,
name=config.run_name,
config=config,
entity=config.wandb_entity,
)
dashboard_logger.add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
c_logger = ConsoleLogger() c_logger = ConsoleLogger()
return config, experiment_path, audio_path, c_logger, dashboard_logger return config, experiment_path, audio_path, c_logger, dashboard_logger

View File

@ -1,3 +1,24 @@
from TTS.utils.logging.console_logger import ConsoleLogger from TTS.utils.logging.console_logger import ConsoleLogger
from TTS.utils.logging.tensorboard_logger import TensorboardLogger from TTS.utils.logging.tensorboard_logger import TensorboardLogger
from TTS.utils.logging.wandb_logger import WandbLogger from TTS.utils.logging.wandb_logger import WandbLogger
def init_logger(config):
if config.dashboard_logger == "tensorboard":
dashboard_logger = TensorboardLogger(config.output_log_path, model_name=config.model)
elif config.dashboard_logger == "wandb":
project_name = config.model
if config.project_name:
project_name = config.project_name
dashboard_logger = WandbLogger(
project=project_name,
name=config.run_name,
config=config,
entity=config.wandb_entity,
)
dashboard_logger.add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
return dashboard_logger