mirror of https://github.com/coqui-ai/TTS.git
Add artifacts logging, wandb args
parent f5e50ad502
commit f606741dc4
@@ -208,7 +208,7 @@ def main(args):  # pylint: disable=redefined-outer-name


 if __name__ == "__main__":
-    args, c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)
+    args, c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger, wandb_logger = init_training(sys.argv)

     try:
         main(args)
@@ -5,8 +5,8 @@ from TTS.trainer import Trainer, init_training

 def main():
     """Run 🐸TTS trainer from terminal. This is also necessary to run DDP training by ```distribute.py```"""
-    args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv)
-    trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=False)
+    args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(sys.argv)
+    trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger, cudnn_benchmark=False)
     trainer.fit()

@@ -8,8 +8,8 @@ from TTS.utils.generic_utils import remove_experiment_folder

 def main():
     try:
-        args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv)
-        trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+        args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(sys.argv)
+        trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
         trainer.fit()
     except KeyboardInterrupt:
         remove_experiment_folder(output_path)
@@ -231,6 +231,16 @@ class BaseTrainingConfig(Coqpit):
         tb_model_param_stats (bool):
             Enable / Disable logging internal model stats for model diagnostic. It might be useful for model debugging.
             Defaults to ```False```.
+        wandb_disabled (bool):
+            Disable W&B logging. Defaults to ```False```.
+        wandb_project_name (str):
+            Name of the W&B project. Defaults to config.model.
+        wandb_entity (str):
+            Name of W&B entity/team. Enables collaboration across a team or org.
+        wandb_log_model_step (int):
+            Number of steps required to log a checkpoint as W&B artifact.
         save_step (int):
             Number of steps required to save the next checkpoint.
@@ -276,6 +286,10 @@ class BaseTrainingConfig(Coqpit):
     print_step: int = 25
     tb_plot_step: int = 100
     tb_model_param_stats: bool = False
+    wandb_disabled: bool = False
+    wandb_project_name: str = None
+    wandb_entity: str = None
+    wandb_log_model_step: int = None
     # checkpointing
     save_step: int = 10000
     checkpoint: bool = True
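
The four new fields above are plain coqpit dataclass fields, so they can be set like any other training option. A minimal usage sketch, assuming the fields land on BaseTrainingConfig exactly as shown; the run, project, and entity values here are illustrative, not from the commit:

from TTS.config.shared_configs import BaseTrainingConfig

config = BaseTrainingConfig(
    run_name="my_tts_run",                 # existing field, reused as the W&B run name below
    wandb_project_name="tts-experiments",  # falls back to config.model when left as None
    wandb_entity="my-team",                # optional W&B team/org
    wandb_log_model_step=10000,            # upload a checkpoint artifact every 10k steps
)
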
@@ -159,7 +159,7 @@ class Trainer:
         self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark)
         if config is None:
             # parse config from console arguments
-            config, output_path, _, c_logger, tb_logger = process_args(args)
+            config, output_path, _, c_logger, tb_logger, wandb_logger = process_args(args)

         self.output_path = output_path
         self.args = args
@@ -657,6 +657,15 @@ class Trainer:
                 self.output_path,
                 model_loss=target_avg_loss,
             )
+
+            if (
+                self.config.wandb_log_model_step
+                and self.total_steps_done % self.config.wandb_log_model_step == 0
+            ):
+                # log checkpoint as W&B artifact
+                aliases = [f"epoch-{self.epochs_done}", f"step-{self.total_steps_done}"]
+                self.wandb_logger.log_artifact(self.output_path, "checkpoint", "model", aliases)
+
         # training visualizations
         figures, audios = None, None
         if hasattr(self.model, "module") and hasattr(self.model.module, "train_log"):
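
For reference, a standalone sketch of what the log_artifact call above amounts to in terms of the public wandb API (wandb.Artifact, add_dir, and run.log_artifact are real calls; the project name, output directory, and aliases are made up):

import wandb

run = wandb.init(project="tts-demo")
artifact = wandb.Artifact(f"{run.id}_checkpoint", type="model")
artifact.add_dir("output/my_run")  # hypothetical checkpoint directory
run.log_artifact(artifact, aliases=["epoch-3", "step-10000"])
run.finish()
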
@@ -860,10 +869,13 @@ class Trainer:
         """Where the ✨️magic✨️ happens..."""
         try:
             self._fit()
+            self.wandb_logger.finish()
         except KeyboardInterrupt:
             self.callbacks.on_keyboard_interrupt()
             # if the output folder is empty remove the run.
             remove_experiment_folder(self.output_path)
+            # finish the wandb run and sync data
+            self.wandb_logger.finish()
             # stop without error signal
             try:
                 sys.exit(0)
@@ -1092,7 +1104,7 @@ def process_args(args, config=None):
             logging to the console.
         tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does
             the TensorBoard logging.
-        wandb_logger (TTS.utils.tensorboard.WandbLogger): Class that does the W&B Loggin
+        wandb_logger (TTS.utils.tensorboard.WandbLogger): Class that does the W&B Logging.

     TODO:
         - Interactive config definition.
@@ -1106,17 +1118,15 @@ def process_args(args, config=None):
         args.restore_path, best_model = get_last_checkpoint(args.continue_path)
         if not args.best_path:
             args.best_path = best_model
-    # setup output paths and read configs
-    if config is None:
-        config = load_config(args.config_path)
-    # init config
+    # init config if not already defined
     if config is None:
         if args.config_path:
             # init from a file
             config = load_config(args.config_path)
         else:
             # init from console args
-            from TTS.config.shared_configs import BaseTrainingConfig
+            from TTS.config.shared_configs import BaseTrainingConfig  # pylint: disable=import-outside-toplevel

             config_base = BaseTrainingConfig()
             config_base.parse_known_args(coqpit_overrides)
@@ -1148,11 +1158,14 @@ def process_args(args, config=None):
     # write model desc to tensorboard
     tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)

-    wandb_logger = WandbLogger(
-        project=config.model,
-        name=os.path.basename(experiment_path),
-        config=config,
-    )
+    if not config.wandb_disabled:
+        wandb_project_name = config.model
+        if config.wandb_project_name:
+            wandb_project_name = config.wandb_project_name
+
+        wandb_logger = WandbLogger(
+            project=wandb_project_name, name=config.run_name, config=config, entity=config.wandb_entity
+        )

     c_logger = ConsoleLogger()
     return config, experiment_path, audio_path, c_logger, tb_logger, wandb_logger
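
The WandbLogger constructor itself is not part of this diff; presumably it forwards these keyword arguments to wandb.init. A hedged sketch of that mapping using the real wandb.init signature (all values illustrative):

import wandb

run = wandb.init(
    project="tts-experiments",  # wandb_project_name, falling back to config.model
    name="my_tts_run",          # config.run_name
    entity="my-team",           # config.wandb_entity; may be None for a personal account
    config={"lr": 1e-4},        # the full coqpit config in the real code
)
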
@@ -2,7 +2,7 @@ from pathlib import Path

 try:
     import wandb
-    from wandb import init, finish
+    from wandb import finish, init  # pylint: disable=W0611
 except ImportError:
     wandb = None

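
The try/except guard above is the standard optional-dependency pattern: if wandb is not installed, the module still imports, and downstream code can check for None before logging. A self-contained version for reference (the re-exported init/finish names are what the pylint disable covers):

try:
    import wandb
    from wandb import finish, init  # noqa: F401  (re-exported for callers)
except ImportError:
    wandb = None

if wandb is None:
    print("wandb is not installed; W&B logging stays disabled")
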
@@ -15,10 +15,6 @@ class WandbLogger:
         self.log_dict = {}

     def log(self, log_dict, prefix="", flush=False):
-        """
-        This function accumulates data in self.log_dict. If flush is set.
-        the accumulated metrics will be logged directly to wandb dashboard.
-        """
         for key, value in log_dict.items():
             self.log_dict[prefix + key] = value
         if flush:  # for cases where you don't want to accumulate data
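
The removed docstring described real behavior that the method keeps: log() accumulates entries in self.log_dict and only sends them on flush. A runnable sketch of that accumulate-then-flush pattern with the W&B dependency stubbed out (DictAccumulator is a made-up name for illustration, not part of the codebase):

class DictAccumulator:
    """Collects prefixed metrics and sends them in one batch on flush."""

    def __init__(self, sink):
        self.sink = sink  # e.g. wandb.log when a run is active
        self.log_dict = {}

    def log(self, log_dict, prefix="", flush=False):
        for key, value in log_dict.items():
            self.log_dict[prefix + key] = value
        if flush:  # send everything accumulated so far, then reset
            self.sink(self.log_dict)
            self.log_dict = {}

acc = DictAccumulator(print)
acc.log({"loss": 0.51})                             # buffered
acc.log({"lr": 1e-4}, prefix="train_", flush=True)  # prints both entries
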
@@ -53,13 +49,14 @@ class WandbLogger:
         self.log_dict = {}

     def finish(self):
-        """
-        Finish this W&B run
-        """
-        self.run.finish()
+        if self.run:
+            self.run.finish()

-    def log_artifact(self, file_or_dir, name, type, aliases=[]):
-        artifact = wandb.Artifact(name, type=type)
+    def log_artifact(self, file_or_dir, name, artifact_type, aliases=None):
+        if not self.run:
+            return
+        name = "_".join([self.run.id, name])
+        artifact = wandb.Artifact(name, type=artifact_type)
         data_path = Path(file_or_dir)
         if data_path.is_dir():
             artifact.add_dir(str(data_path))
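
Two correctness points in the new log_artifact are worth noting: the old aliases=[] default was the classic mutable-default pitfall (one shared list object across all calls), and the early return makes the method a no-op when wandb never initialized. A condensed sketch; the file branch is an assumption on my part, since the diff is truncated at this point:

from pathlib import Path

import wandb

def log_artifact(run, file_or_dir, name, artifact_type, aliases=None):
    if not run:  # wandb disabled or not installed: do nothing
        return
    artifact = wandb.Artifact(f"{run.id}_{name}", type=artifact_type)
    data_path = Path(file_or_dir)
    if data_path.is_dir():
        artifact.add_dir(str(data_path))
    elif data_path.is_file():  # assumed: the truncated hunk likely handles single files too
        artifact.add_file(str(data_path))
    run.log_artifact(artifact, aliases=aliases)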