mirror of https://github.com/coqui-ai/TTS.git
Update disabled structure
This commit is contained in:
parent
f606741dc4
commit
f4434da5a3
|
@ -174,10 +174,16 @@ class Trainer:
|
||||||
self.tb_logger = tb_logger
|
self.tb_logger = tb_logger
|
||||||
|
|
||||||
if wandb_logger is None:
|
if wandb_logger is None:
|
||||||
self.wandb_logger = WandbLogger(
|
wandb_project_name = config.model
|
||||||
project=os.path(output_path).stem,
|
if config.wandb_project_name:
|
||||||
name=config.model,
|
wandb_project_name = config.wandb_project_name
|
||||||
|
|
||||||
|
wandb_logger = WandbLogger(
|
||||||
|
disabled=config.wandb_disabled,
|
||||||
|
project=wandb_project_name,
|
||||||
|
name=config.run_name,
|
||||||
config=config,
|
config=config,
|
||||||
|
entity=config.wandb_entity,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.wandb_logger = wandb_logger
|
self.wandb_logger = wandb_logger
|
||||||
|
@ -1158,14 +1164,17 @@ def process_args(args, config=None):
|
||||||
# write model desc to tensorboard
|
# write model desc to tensorboard
|
||||||
tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
|
tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
|
||||||
|
|
||||||
if not config.wandb_disabled:
|
wandb_project_name = config.model
|
||||||
wandb_project_name = config.model
|
if config.wandb_project_name:
|
||||||
if config.wandb_project_name:
|
wandb_project_name = config.wandb_project_name
|
||||||
wandb_project_name = config.wandb_project_name
|
|
||||||
|
|
||||||
wandb_logger = WandbLogger(
|
wandb_logger = WandbLogger(
|
||||||
project=wandb_project_name, name=config.run_name, config=config, entity=config.wandb_entity
|
disabled=config.wandb_disabled,
|
||||||
)
|
project=wandb_project_name,
|
||||||
|
name=config.run_name,
|
||||||
|
config=config,
|
||||||
|
entity=config.wandb_entity,
|
||||||
|
)
|
||||||
|
|
||||||
c_logger = ConsoleLogger()
|
c_logger = ConsoleLogger()
|
||||||
return config, experiment_path, audio_path, c_logger, tb_logger, wandb_logger
|
return config, experiment_path, audio_path, c_logger, tb_logger, wandb_logger
|
||||||
|
|
|
@ -8,9 +8,9 @@ except ImportError:
|
||||||
|
|
||||||
|
|
||||||
class WandbLogger:
|
class WandbLogger:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, disabled=False, **kwargs):
|
||||||
self.run = None
|
self.run = None
|
||||||
if wandb:
|
if wandb and not disabled:
|
||||||
self.run = wandb.init(**kwargs) if not wandb.run else wandb.run
|
self.run = wandb.init(**kwargs) if not wandb.run else wandb.run
|
||||||
self.log_dict = {}
|
self.log_dict = {}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue