mirror of https://github.com/coqui-ai/TTS.git
Refactor logging initialization
This commit is contained in:
parent
79b74a989d
commit
f3e9d61330
|
@ -38,7 +38,7 @@ from TTS.utils.generic_utils import (
|
||||||
to_cuda,
|
to_cuda,
|
||||||
)
|
)
|
||||||
from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
|
from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
|
||||||
from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger
|
from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_logger
|
||||||
from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
|
from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
|
||||||
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
||||||
from TTS.vocoder.models import setup_model as setup_vocoder_model
|
from TTS.vocoder.models import setup_model as setup_vocoder_model
|
||||||
|
@ -160,28 +160,14 @@ class Trainer:
|
||||||
self.output_path = output_path
|
self.output_path = output_path
|
||||||
self.args = args
|
self.args = args
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.config.output_log_path = output_path
|
||||||
# init loggers
|
# init loggers
|
||||||
self.c_logger = ConsoleLogger() if c_logger is None else c_logger
|
self.c_logger = ConsoleLogger() if c_logger is None else c_logger
|
||||||
self.dashboard_logger = dashboard_logger
|
self.dashboard_logger = dashboard_logger
|
||||||
|
|
||||||
if self.dashboard_logger is None:
|
if self.dashboard_logger is None:
|
||||||
if config.dashboard_logger == "tensorboard":
|
self.dashboard_logger = init_logger(config)
|
||||||
self.dashboard_logger = TensorboardLogger(output_path, model_name=config.model)
|
|
||||||
|
|
||||||
elif config.dashboard_logger == "wandb":
|
|
||||||
project_name = config.model
|
|
||||||
if config.project_name:
|
|
||||||
project_name = config.project_name
|
|
||||||
|
|
||||||
self.dashboard_logger = WandbLogger(
|
|
||||||
project=project_name,
|
|
||||||
name=config.run_name,
|
|
||||||
config=config,
|
|
||||||
entity=config.wandb_entity,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.dashboard_logger.add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
|
|
||||||
if not self.config.log_model_step:
|
if not self.config.log_model_step:
|
||||||
self.config.log_model_step = self.config.save_step
|
self.config.log_model_step = self.config.save_step
|
||||||
|
|
||||||
|
@ -1129,6 +1115,7 @@ def process_args(args, config=None):
|
||||||
if not experiment_path:
|
if not experiment_path:
|
||||||
experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
|
experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
|
||||||
audio_path = os.path.join(experiment_path, "test_audios")
|
audio_path = os.path.join(experiment_path, "test_audios")
|
||||||
|
config.output_log_path = experiment_path
|
||||||
# setup rank 0 process in distributed training
|
# setup rank 0 process in distributed training
|
||||||
dashboard_logger = None
|
dashboard_logger = None
|
||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
|
@ -1146,21 +1133,7 @@ def process_args(args, config=None):
|
||||||
os.chmod(audio_path, 0o775)
|
os.chmod(audio_path, 0o775)
|
||||||
os.chmod(experiment_path, 0o775)
|
os.chmod(experiment_path, 0o775)
|
||||||
|
|
||||||
if config.dashboard_logger == "tensorboard":
|
dashboard_logger = init_logger(config)
|
||||||
dashboard_logger = TensorboardLogger(config.output_path, model_name=config.model)
|
|
||||||
|
|
||||||
elif config.dashboard_logger == "wandb":
|
|
||||||
project_name = config.model
|
|
||||||
if config.project_name:
|
|
||||||
project_name = config.project_name
|
|
||||||
|
|
||||||
dashboard_logger = WandbLogger(
|
|
||||||
project=project_name,
|
|
||||||
name=config.run_name,
|
|
||||||
config=config,
|
|
||||||
entity=config.wandb_entity,
|
|
||||||
)
|
|
||||||
dashboard_logger.add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
|
|
||||||
c_logger = ConsoleLogger()
|
c_logger = ConsoleLogger()
|
||||||
return config, experiment_path, audio_path, c_logger, dashboard_logger
|
return config, experiment_path, audio_path, c_logger, dashboard_logger
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,24 @@
|
||||||
from TTS.utils.logging.console_logger import ConsoleLogger
|
from TTS.utils.logging.console_logger import ConsoleLogger
|
||||||
from TTS.utils.logging.tensorboard_logger import TensorboardLogger
|
from TTS.utils.logging.tensorboard_logger import TensorboardLogger
|
||||||
from TTS.utils.logging.wandb_logger import WandbLogger
|
from TTS.utils.logging.wandb_logger import WandbLogger
|
||||||
|
|
||||||
|
|
||||||
|
def init_logger(config):
|
||||||
|
if config.dashboard_logger == "tensorboard":
|
||||||
|
dashboard_logger = TensorboardLogger(config.output_log_path, model_name=config.model)
|
||||||
|
|
||||||
|
elif config.dashboard_logger == "wandb":
|
||||||
|
project_name = config.model
|
||||||
|
if config.project_name:
|
||||||
|
project_name = config.project_name
|
||||||
|
|
||||||
|
dashboard_logger = WandbLogger(
|
||||||
|
project=project_name,
|
||||||
|
name=config.run_name,
|
||||||
|
config=config,
|
||||||
|
entity=config.wandb_entity,
|
||||||
|
)
|
||||||
|
|
||||||
|
dashboard_logger.add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
|
||||||
|
|
||||||
|
return dashboard_logger
|
||||||
|
|
Loading…
Reference in New Issue