Unified logger API

Ayush Chaurasia 2021-07-09 14:03:12 +05:30 committed by Eren Gölge
parent f4434da5a3
commit f63cf46c55
7 changed files with 168 additions and 124 deletions
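
The change collapses the separate `tb_logger` / `wandb_logger` pair into a single `dashboard_logger`, selected by the new `config.dashboard_logger` field ("tensorboard" or "wandb"). A condensed before/after sketch of the entry-point call shape, as it appears in the hunks below:

    import sys

    from TTS.trainer import Trainer, init_training

    # before: every entry point unpacked two loggers and passed both to Trainer
    # args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(sys.argv)

    # after: one dashboard logger, chosen by config.dashboard_logger
    args, config, output_path, _, c_logger, dashboard_logger = init_training(sys.argv)
    trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
    trainer.fit()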

View File

@@ -208,7 +208,7 @@ def main(args):  # pylint: disable=redefined-outer-name

 if __name__ == "__main__":
-    args, c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger, wandb_logger = init_training(sys.argv)
+    args, c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training(sys.argv)
     try:
         main(args)

View File

@@ -5,8 +5,8 @@ from TTS.trainer import Trainer, init_training

 def main():
     """Run 🐸TTS trainer from terminal. This is also necessary to run DDP training by ```distribute.py```"""
-    args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(sys.argv)
-    trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger, cudnn_benchmark=False)
+    args, config, output_path, _, c_logger, dashboard_logger = init_training(sys.argv)
+    trainer = Trainer(args, config, output_path, c_logger, dashboard_logger, cudnn_benchmark=False)
     trainer.fit()

View File

@@ -8,8 +8,8 @@ from TTS.utils.generic_utils import remove_experiment_folder

 def main():
     try:
-        args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(sys.argv)
-        trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
+        args, config, output_path, _, c_logger, dashboard_logger = init_training(sys.argv)
+        trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
         trainer.fit()
     except KeyboardInterrupt:
         remove_experiment_folder(output_path)

View File

@@ -225,15 +225,17 @@ class BaseTrainingConfig(Coqpit):
         print_step (int):
             Number of steps required to print the next training log.
-        tb_plot_step (int):
+        dashboard_logger (str): "tensorboard" or "wandb"
+            Set the experiment tracking tool.
+        plot_step (int):
             Number of steps required to log training on Tensorboard.
-        tb_model_param_stats (bool):
+        model_param_stats (bool):
             Enable / Disable logging internal model stats for model diagnostic. It might be useful for model debugging.
             Defaults to ```False```.
-        wandb_disabled: bool = False
-        wandb_project_name (str):
+        project_name (str):
             Name of the W&B project. Defaults to config.model
         wandb_entity (str):
@@ -283,13 +285,13 @@ class BaseTrainingConfig(Coqpit):
     test_delay_epochs: int = 0
     print_eval: bool = False
     # logging
+    dashboard_logger: str = "tensorboard"
     print_step: int = 25
-    tb_plot_step: int = 100
-    tb_model_param_stats: bool = False
-    wandb_disabled: bool = False
-    wandb_project_name: str = None
+    plot_step: int = 100
+    model_param_stats: bool = False
+    project_name: str = None
+    log_model_step: int = None
     wandb_entity: str = None
-    wandb_log_model_step: int = None
     # checkpointing
     save_step: int = 10000
     checkpoint: bool = True
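
The renamed fields drop their backend prefixes so one knob configures either dashboard. A minimal sketch of a config using them; the import path is assumed, and in practice model configs subclass `BaseTrainingConfig` rather than instantiating it directly:

    from TTS.config.shared_configs import BaseTrainingConfig  # import path assumed

    config = BaseTrainingConfig(
        dashboard_logger="wandb",       # "tensorboard" (default) or "wandb"
        project_name="my-tts-project",  # was wandb_project_name; falls back to config.model
        wandb_entity="my-team",
        plot_step=100,                  # was tb_plot_step
        model_param_stats=False,        # was tb_model_param_stats
        log_model_step=10000,           # was wandb_log_model_step; defaults to save_step when unset
    )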

View File

@ -92,8 +92,7 @@ class Trainer:
config: Coqpit,
output_path: str,
c_logger: ConsoleLogger = None,
tb_logger: TensorboardLogger = None,
wandb_logger: WandbLogger = None,
dashboard_logger: Union[TensorboardLogger, WandbLogger] = None,
model: nn.Module = None,
cudnn_benchmark: bool = False,
) -> None:
@@ -118,10 +117,7 @@ class Trainer:
             c_logger (ConsoleLogger, optional): Console logger for printing training status. If not provided, the default
                 console logger is used. Defaults to None.
-            tb_logger (TensorboardLogger, optional): Tensorboard logger. If not provided, the default logger is used.
-                Defaults to None.
-            wandb_logger (WandbLogger, optional): W&B logger. If not provided, the default logger is used.
+            dashboard_logger (Union[TensorboardLogger, WandbLogger], optional): Dashboard logger. If not provided, the TensorBoard logger is used.
                 Defaults to None.
             model (nn.Module, optional): Initialized and ready-to-train model. If it is not defined, `Trainer`
@@ -143,8 +139,8 @@ class Trainer:
         Running trainer on a config.

         >>> config = WavegradConfig(data_path="/home/erogol/nvme/gdrive/Datasets/LJSpeech-1.1/wavs/", output_path=output_path,)
-        >>> args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-        >>> trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+        >>> args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
+        >>> trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
         >>> trainer.fit()

     TODO:
@@ -159,7 +155,7 @@ class Trainer:
         self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark)
         if config is None:
             # parse config from console arguments
-            config, output_path, _, c_logger, tb_logger, wandb_logger = process_args(args)
+            config, output_path, _, c_logger, dashboard_logger = process_args(args)

         self.output_path = output_path
         self.args = args
@@ -167,26 +163,27 @@ class Trainer:
         # init loggers
         self.c_logger = ConsoleLogger() if c_logger is None else c_logger
-        if tb_logger is None:
-            self.tb_logger = TensorboardLogger(output_path, model_name=config.model)
-            self.tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
-        else:
-            self.tb_logger = tb_logger
-        if wandb_logger is None:
-            wandb_project_name = config.model
-            if config.wandb_project_name:
-                wandb_project_name = config.wandb_project_name
-            wandb_logger = WandbLogger(
-                disabled=config.wandb_disabled,
-                project=wandb_project_name,
-                name=config.run_name,
-                config=config,
-                entity=config.wandb_entity,
-            )
-        else:
-            self.wandb_logger = wandb_logger
+        self.dashboard_logger = dashboard_logger
+        if self.dashboard_logger is None:
+            if config.dashboard_logger == "tensorboard":
+                self.dashboard_logger = TensorboardLogger(output_path, model_name=config.model)
+                self.dashboard_logger.add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
+            elif config.dashboard_logger == "wandb":
+                project_name = config.model
+                if config.project_name:
+                    project_name = config.project_name
+                self.dashboard_logger = WandbLogger(
+                    project=project_name,
+                    name=config.run_name,
+                    config=config,
+                    entity=config.wandb_entity,
+                )
+        if not self.config.log_model_step:
+            self.config.log_model_step = self.config.save_step

         log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
         self._setup_logger_config(log_file)
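
Since the parameter is now a single duck-typed object, a caller can also build the dashboard logger itself and hand it to `Trainer`, bypassing `config.dashboard_logger`. A hedged sketch; the `WandbLogger` import path is assumed, as this diff does not show the module layout:

    import sys

    from TTS.trainer import Trainer, init_training
    from TTS.utils.wandb_logger import WandbLogger  # import path assumed

    args, config, output_path, _, c_logger, _ = init_training(sys.argv)
    # construct the W&B logger explicitly; kwargs are forwarded to wandb.init()
    dashboard_logger = WandbLogger(project="my-tts-project", name=config.run_name, config=config)
    trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
    trainer.fit()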
@@ -646,9 +643,8 @@ class Trainer:
         if self.args.rank == 0:
             # Plot Training Iter Stats
             # reduce TB load and don't log every step
-            if self.total_steps_done % self.config.tb_plot_step == 0:
-                self.tb_logger.tb_train_step_stats(self.total_steps_done, loss_dict)
-                self.wandb_logger.log(loss_dict, "train/", flush=True)
+            if self.total_steps_done % self.config.plot_step == 0:
+                self.dashboard_logger.train_step_stats(self.total_steps_done, loss_dict)

             if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0:
                 if self.config.checkpoint:
                     # checkpoint the model
@@ -664,13 +660,10 @@ class Trainer:
                         model_loss=target_avg_loss,
                     )
-                    if (
-                        self.config.wandb_log_model_step
-                        and self.total_steps_done % self.config.wandb_log_model_step == 0
-                    ):
-                        # log checkpoint as W&B artifact
+                    if self.total_steps_done % self.config.log_model_step == 0:
+                        # log checkpoint as artifact
                         aliases = [f"epoch-{self.epochs_done}", f"step-{self.total_steps_done}"]
-                        self.wandb_logger.log_artifact(self.output_path, "checkpoint", "model", aliases)
+                        self.dashboard_logger.log_artifact(self.output_path, "checkpoint", "model", aliases)

             # training visualizations
             figures, audios = None, None
@@ -679,15 +672,13 @@ class Trainer:
             elif hasattr(self.model, "train_log"):
                 figures, audios = self.model.train_log(self.ap, batch, outputs)
             if figures is not None:
-                self.tb_logger.tb_train_figures(self.total_steps_done, figures)
-                self.wandb_logger.log_figures(figures, "train/")
+                self.dashboard_logger.train_figures(self.total_steps_done, figures)
             if audios is not None:
-                self.tb_logger.tb_train_audios(self.total_steps_done, audios, self.ap.sample_rate)
-                self.wandb_logger.log_audios(audios, self.ap.sample_rate, "train/")
+                self.dashboard_logger.train_audios(self.total_steps_done, audios, self.ap.sample_rate)

         self.total_steps_done += 1
         self.callbacks.on_train_step_end()
-        self.wandb_logger.flush()
+        self.dashboard_logger.flush()
         return outputs, loss_dict

     def train_epoch(self) -> None:
@@ -712,10 +703,9 @@ class Trainer:
         if self.args.rank == 0:
             epoch_stats = {"epoch_time": epoch_time}
             epoch_stats.update(self.keep_avg_train.avg_values)
-            self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
-            self.wandb_logger.log_scalars(epoch_stats, "train/")
-            if self.config.tb_model_param_stats:
-                self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
+            self.dashboard_logger.train_epoch_stats(self.total_steps_done, epoch_stats)
+            if self.config.model_param_stats:
+                self.dashboard_logger.model_weights(self.model, self.total_steps_done)

         # scheduler step after the epoch
         if self.scheduler is not None and self.config.scheduler_after_epoch:
             if isinstance(self.scheduler, list):
@@ -816,13 +806,10 @@ class Trainer:
             elif hasattr(self.model, "eval_log"):
                 figures, audios = self.model.eval_log(self.ap, batch, outputs)
             if figures is not None:
-                self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
-                self.wandb_logger.log_figures(figures, "eval/")
+                self.dashboard_logger.eval_figures(self.total_steps_done, figures)
             if audios is not None:
-                self.tb_logger.tb_eval_audios(self.total_steps_done, audios, self.ap.sample_rate)
-                self.wandb_logger.log_audios(audios, self.ap.sample_rate, "eval/")
-            self.tb_logger.tb_eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values)
-            self.wandb_logger.log_scalars(self.keep_avg_eval.avg_values, "eval/")
+                self.dashboard_logger.eval_audios(self.total_steps_done, audios, self.ap.sample_rate)
+            self.dashboard_logger.eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values)

     def test_run(self) -> None:
         """Run test and log the results. Test run must be defined by the model.
@@ -839,11 +826,9 @@ class Trainer:
             samples = self.eval_loader.dataset.load_test_samples(1)
             figures, audios = self.model.test_run(self.ap, samples, None)
         else:
-            figures, audios = self.model.test_run(self.ap)
-        self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
-        self.tb_logger.tb_test_figures(self.total_steps_done, figures)
-        self.wandb_logger.log_audios(audios, self.config.audio["sample_rate"], "test/")
-        self.wandb_logger.log_figures(figures, "test/")
+            figures, audios = self.model.test_run()
+        self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
+        self.dashboard_logger.test_figures(self.total_steps_done, figures)

     def _fit(self) -> None:
         """🏃 train -> evaluate -> test for the number of epochs."""
@@ -875,13 +860,13 @@ class Trainer:
         """Where the ✨magic✨ happens..."""
         try:
             self._fit()
-            self.wandb_logger.finish()
+            self.dashboard_logger.finish()
         except KeyboardInterrupt:
             self.callbacks.on_keyboard_interrupt()
             # if the output folder is empty remove the run.
             remove_experiment_folder(self.output_path)
             # finish the wandb run and sync data
-            self.wandb_logger.finish()
+            self.dashboard_logger.finish()
             # stop without error signal
             try:
                 sys.exit(0)
@@ -1108,9 +1093,8 @@ def process_args(args, config=None):
         audio_path (str): Path to save generated test audios.
         c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
             logging to the console.
-        tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does
-            the TensorBoard logging.
-        wandb_logger (TTS.utils.tensorboard.WandbLogger): Class that does the W&B Logging
+        dashboard_logger (WandbLogger or TensorboardLogger): Class that does the dashboard logging.

     TODO:
         - Interactive config definition.
@@ -1146,8 +1130,7 @@ def process_args(args, config=None):
     experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
     audio_path = os.path.join(experiment_path, "test_audios")
     # setup rank 0 process in distributed training
-    tb_logger = None
-    wandb_logger = None
+    dashboard_logger = None
     if args.rank == 0:
         new_fields = {}
         if args.restore_path:
@@ -1160,24 +1143,28 @@ def process_args(args, config=None):
             used_characters = parse_symbols()
             new_fields["characters"] = used_characters
         copy_model_files(config, experiment_path, new_fields)
-        tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
-        # write model desc to tensorboard
-        tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
         os.chmod(audio_path, 0o775)
         os.chmod(experiment_path, 0o775)
-        wandb_project_name = config.model
-        if config.wandb_project_name:
-            wandb_project_name = config.wandb_project_name
-        wandb_logger = WandbLogger(
-            disabled=config.wandb_disabled,
-            project=wandb_project_name,
-            name=config.run_name,
-            config=config,
-            entity=config.wandb_entity,
-        )
+        if config.dashboard_logger == "tensorboard":
+            dashboard_logger = TensorboardLogger(experiment_path, model_name=config.model)
+            dashboard_logger.add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
+        elif config.dashboard_logger == "wandb":
+            project_name = config.model
+            if config.project_name:
+                project_name = config.project_name
+            dashboard_logger = WandbLogger(
+                project=project_name,
+                name=config.run_name,
+                config=config,
+                entity=config.wandb_entity,
+            )

     c_logger = ConsoleLogger()
-    return config, experiment_path, audio_path, c_logger, tb_logger, wandb_logger
+    return config, experiment_path, audio_path, c_logger, dashboard_logger


 def init_arguments():
@@ -1193,5 +1180,5 @@ def init_training(argv: Union[List, Coqpit], config: Coqpit = None):
     else:
         parser = init_arguments()
         args = parser.parse_known_args()
-    config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger, wandb_logger = process_args(args, config)
-    return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger, wandb_logger
+    config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = process_args(args, config)
+    return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger

View File

@@ -10,7 +10,7 @@ class TensorboardLogger(object):
         self.train_stats = {}
         self.eval_stats = {}

-    def tb_model_weights(self, model, step):
+    def model_weights(self, model, step):
         layer_num = 1
         for name, param in model.named_parameters():
             if param.numel() == 1:
@@ -41,32 +41,41 @@ class TensorboardLogger(object):
         except RuntimeError:
             traceback.print_exc()

-    def tb_train_step_stats(self, step, stats):
+    def train_step_stats(self, step, stats):
         self.dict_to_tb_scalar(f"{self.model_name}_TrainIterStats", stats, step)

-    def tb_train_epoch_stats(self, step, stats):
+    def train_epoch_stats(self, step, stats):
         self.dict_to_tb_scalar(f"{self.model_name}_TrainEpochStats", stats, step)

-    def tb_train_figures(self, step, figures):
+    def train_figures(self, step, figures):
         self.dict_to_tb_figure(f"{self.model_name}_TrainFigures", figures, step)

-    def tb_train_audios(self, step, audios, sample_rate):
+    def train_audios(self, step, audios, sample_rate):
         self.dict_to_tb_audios(f"{self.model_name}_TrainAudios", audios, step, sample_rate)

-    def tb_eval_stats(self, step, stats):
+    def eval_stats(self, step, stats):
         self.dict_to_tb_scalar(f"{self.model_name}_EvalStats", stats, step)

-    def tb_eval_figures(self, step, figures):
+    def eval_figures(self, step, figures):
         self.dict_to_tb_figure(f"{self.model_name}_EvalFigures", figures, step)

-    def tb_eval_audios(self, step, audios, sample_rate):
+    def eval_audios(self, step, audios, sample_rate):
         self.dict_to_tb_audios(f"{self.model_name}_EvalAudios", audios, step, sample_rate)

-    def tb_test_audios(self, step, audios, sample_rate):
+    def test_audios(self, step, audios, sample_rate):
         self.dict_to_tb_audios(f"{self.model_name}_TestAudios", audios, step, sample_rate)

-    def tb_test_figures(self, step, figures):
+    def test_figures(self, step, figures):
         self.dict_to_tb_figure(f"{self.model_name}_TestFigures", figures, step)

-    def tb_add_text(self, title, text, step):
+    def add_text(self, title, text, step):
         self.writer.add_text(title, text, step)

+    def log_artifact(self, file_or_dir, name, artifact_type, aliases=None):
+        return
+
+    def flush(self):
+        return
+
+    def finish(self):
+        return
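
The new no-op `log_artifact` / `flush` / `finish` stubs are what make the two backends interchangeable: the trainer calls one surface and TensorBoard simply ignores the W&B-only hooks. A small sketch of that shared surface, using only methods visible in this diff (constructor arguments taken from the call sites above):

    logger = TensorboardLogger("output/run-1", model_name="tacotron2")
    logger.add_text("model-config", "<pre>{...}</pre>", 0)
    logger.train_step_stats(100, {"loss": 0.42})
    logger.log_artifact("output/run-1", "checkpoint", "model")  # no-op here; uploads a W&B artifact on the other backend
    logger.flush()   # no-op here; W&B pushes accumulated logs
    logger.finish()  # no-op here; W&B closes the run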

View File

@@ -1,4 +1,5 @@
+from pathlib import Path
 import traceback

 try:
     import wandb
@@ -8,40 +9,85 @@ except ImportError:

 class WandbLogger:
-    def __init__(self, disabled=False, **kwargs):
+    def __init__(self, **kwargs):
+        if not wandb:
+            raise Exception("install wandb using `pip install wandb` to use WandbLogger")
         self.run = None
-        if wandb and not disabled:
-            self.run = wandb.init(**kwargs) if not wandb.run else wandb.run
+        self.run = wandb.init(**kwargs) if not wandb.run else wandb.run
+        self.model_name = self.run.config.model
         self.log_dict = {}

+    def model_weights(self, model):
+        layer_num = 1
+        for name, param in model.named_parameters():
+            if param.numel() == 1:
+                self.dict_to_scalar("weights", {"layer{}-{}/value".format(layer_num, name): param.max()})
+            else:
+                self.dict_to_scalar("weights", {"layer{}-{}/max".format(layer_num, name): param.max()})
+                self.dict_to_scalar("weights", {"layer{}-{}/min".format(layer_num, name): param.min()})
+                self.dict_to_scalar("weights", {"layer{}-{}/mean".format(layer_num, name): param.mean()})
+                self.dict_to_scalar("weights", {"layer{}-{}/std".format(layer_num, name): param.std()})
+                '''
+                self.writer.add_histogram("layer{}-{}/param".format(layer_num, name), param, step)
+                self.writer.add_histogram("layer{}-{}/grad".format(layer_num, name), param.grad, step)
+                '''
+            layer_num += 1
+
+    def dict_to_scalar(self, scope_name, stats):
+        for key, value in stats.items():
+            self.log_dict["{}/{}".format(scope_name, key)] = value
+
+    def dict_to_figure(self, scope_name, figures):
+        for key, value in figures.items():
+            self.log_dict["{}/{}".format(scope_name, key)] = wandb.Image(value)
+
+    def dict_to_audios(self, scope_name, audios, sample_rate):
+        for key, value in audios.items():
+            if value.dtype == "float16":
+                value = value.astype("float32")
+            try:
+                self.log_dict["{}/{}".format(scope_name, key)] = wandb.Audio(value, sample_rate=sample_rate)
+            except RuntimeError:
+                traceback.print_exc()

-    def log(self, log_dict, prefix="", flush=False):
-        for key, value in log_dict.items():
-            self.log_dict[prefix + key] = value
-        if flush:  # for cases where you don't want to accumulate data
-            self.flush()
-
-    def log_scalars(self, log_dict, prefix=""):
-        if not self.run:
-            return
-        for key, value in log_dict.items():
-            self.log_dict[prefix + key] = value
-
-    def log_audios(self, log_dict, sample_rate, prefix=""):
-        if not self.run:
-            return
-        prefix = "audios/" + prefix
-        for key, value in log_dict.items():
-            self.log_dict[prefix + key] = wandb.Audio(value, sample_rate=int(sample_rate))
-
-    def log_figures(self, log_dict, prefix=""):
-        if not self.run:
-            return
-        prefix = "figures/" + prefix
-        for key, value in log_dict.items():
-            self.log_dict[prefix + key] = wandb.Image(value)
+    def train_step_stats(self, step, stats):
+        self.dict_to_scalar(f"{self.model_name}_TrainIterStats", stats)
+
+    def train_epoch_stats(self, step, stats):
+        self.dict_to_scalar(f"{self.model_name}_TrainEpochStats", stats)
+
+    def train_figures(self, step, figures):
+        self.dict_to_figure(f"{self.model_name}_TrainFigures", figures)
+
+    def train_audios(self, step, audios, sample_rate):
+        self.dict_to_audios(f"{self.model_name}_TrainAudios", audios, sample_rate)
+
+    def eval_stats(self, step, stats):
+        self.dict_to_scalar(f"{self.model_name}_EvalStats", stats)
+
+    def eval_figures(self, step, figures):
+        self.dict_to_figure(f"{self.model_name}_EvalFigures", figures)
+
+    def eval_audios(self, step, audios, sample_rate):
+        self.dict_to_audios(f"{self.model_name}_EvalAudios", audios, sample_rate)
+
+    def test_audios(self, step, audios, sample_rate):
+        self.dict_to_audios(f"{self.model_name}_TestAudios", audios, sample_rate)
+
+    def test_figures(self, step, figures):
+        self.dict_to_figure(f"{self.model_name}_TestFigures", figures)
+
+    def add_text(self, title, text, step):
+        self.log_dict[title] = wandb.HTML(f"<p> {text} </p>")

     def flush(self):
         if self.run: