diff --git a/TTS/trainer.py b/TTS/trainer.py index 48ea92b5..68628aed 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -174,10 +174,16 @@ class Trainer: self.tb_logger = tb_logger if wandb_logger is None: - self.wandb_logger = WandbLogger( - project=os.path(output_path).stem, - name=config.model, + wandb_project_name = config.model + if config.wandb_project_name: + wandb_project_name = config.wandb_project_name + + wandb_logger = WandbLogger( + disabled=config.wandb_disabled, + project=wandb_project_name, + name=config.run_name, config=config, + entity=config.wandb_entity, ) else: self.wandb_logger = wandb_logger @@ -1158,14 +1164,17 @@ def process_args(args, config=None): # write model desc to tensorboard tb_logger.tb_add_text("model-config", f"
{config.to_json()}
", 0) - if not config.wandb_disabled: - wandb_project_name = config.model - if config.wandb_project_name: - wandb_project_name = config.wandb_project_name + wandb_project_name = config.model + if config.wandb_project_name: + wandb_project_name = config.wandb_project_name - wandb_logger = WandbLogger( - project=wandb_project_name, name=config.run_name, config=config, entity=config.wandb_entity - ) + wandb_logger = WandbLogger( + disabled=config.wandb_disabled, + project=wandb_project_name, + name=config.run_name, + config=config, + entity=config.wandb_entity, + ) c_logger = ConsoleLogger() return config, experiment_path, audio_path, c_logger, tb_logger, wandb_logger diff --git a/TTS/utils/logging/wandb_logger.py b/TTS/utils/logging/wandb_logger.py index e8f6765b..6c09c2a3 100644 --- a/TTS/utils/logging/wandb_logger.py +++ b/TTS/utils/logging/wandb_logger.py @@ -8,9 +8,9 @@ except ImportError: class WandbLogger: - def __init__(self, **kwargs): + def __init__(self, disabled=False, **kwargs): self.run = None - if wandb: + if wandb and not disabled: self.run = wandb.init(**kwargs) if not wandb.run else wandb.run self.log_dict = {}