diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index a41e29a8..d419d50e 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -208,7 +208,7 @@ def main(args): # pylint: disable=redefined-outer-name if __name__ == "__main__": - args, c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger, wandb_logger = init_training(sys.argv) + args, c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training(sys.argv) try: main(args) diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index be8d6200..863bd3b9 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -5,8 +5,8 @@ from TTS.trainer import Trainer, init_training def main(): """Run 🐸TTS trainer from terminal. This is also necessary to run DDP training by ```distribute.py```""" - args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(sys.argv) - trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger, cudnn_benchmark=False) + args, config, output_path, _, c_logger, dashboard_logger = init_training(sys.argv) + trainer = Trainer(args, config, output_path, c_logger, dashboard_logger, cudnn_benchmark=False) trainer.fit() diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py index 1eac603e..000083e0 100644 --- a/TTS/bin/train_vocoder.py +++ b/TTS/bin/train_vocoder.py @@ -8,8 +8,8 @@ from TTS.utils.generic_utils import remove_experiment_folder def main(): try: - args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(sys.argv) - trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger) + args, config, output_path, _, c_logger, dashboard_logger = init_training(sys.argv) + trainer = Trainer(args, config, output_path, c_logger, dashboard_logger) trainer.fit() except KeyboardInterrupt: remove_experiment_folder(output_path) diff --git a/TTS/config/shared_configs.py b/TTS/config/shared_configs.py index 8d3b108e..ebc18e8c 100644 --- a/TTS/config/shared_configs.py +++ b/TTS/config/shared_configs.py @@ -225,15 +225,17 @@ class BaseTrainingConfig(Coqpit): print_step (int): Number of steps required to print the next training log. - tb_plot_step (int): + log_dashboard (str): "tensorboard" or "wandb" + Set the experiment tracking tool + + plot_step (int): Number of steps required to log training on Tensorboard. - tb_model_param_stats (bool): + model_param_stats (bool): Enable / Disable logging internal model stats for model diagnostic. It might be useful for model debugging. Defaults to ```False```. - wandb_disabled: bool = False - wandb_project_name (str): + project_name (str): Name of the W&B project. Defaults to config.model wandb_entity (str): @@ -283,13 +285,13 @@ class BaseTrainingConfig(Coqpit): test_delay_epochs: int = 0 print_eval: bool = False # logging + dashboard_logger: str = "tensorboard" print_step: int = 25 - tb_plot_step: int = 100 - tb_model_param_stats: bool = False - wandb_disabled: bool = False - wandb_project_name: str = None + plot_step: int = 100 + model_param_stats: bool = False + project_name: str = None + log_model_step: int = None wandb_entity: str = None - wandb_log_model_step: int = None # checkpointing save_step: int = 10000 checkpoint: bool = True diff --git a/TTS/trainer.py b/TTS/trainer.py index 68628aed..8edb75cf 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -92,8 +92,7 @@ class Trainer: config: Coqpit, output_path: str, c_logger: ConsoleLogger = None, - tb_logger: TensorboardLogger = None, - wandb_logger: WandbLogger = None, + dashboard_logger: Union[TensorboardLogger, WandbLogger] = None, model: nn.Module = None, cudnn_benchmark: bool = False, ) -> None: @@ -118,10 +117,7 @@ class Trainer: c_logger (ConsoleLogger, optional): Console logger for printing training status. If not provided, the default console logger is used. Defaults to None. - tb_logger (TensorboardLogger, optional): Tensorboard logger. If not provided, the default logger is used. - Defaults to None. - - wandb_logger (WandbLogger, optional): W&B logger. If not provided, the default logger is used. + dashboard_logger Union[TensorboardLogger, WandbLogger]: Dashboard logger. If not provided, the tensorboard logger is used. Defaults to None. model (nn.Module, optional): Initialized and ready-to-train model. If it is not defined, `Trainer` @@ -143,8 +139,8 @@ class Trainer: Running trainer on a config. >>> config = WavegradConfig(data_path="/home/erogol/nvme/gdrive/Datasets/LJSpeech-1.1/wavs/", output_path=output_path,) - >>> args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) - >>> trainer = Trainer(args, config, output_path, c_logger, tb_logger) + >>> args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config) + >>> trainer = Trainer(args, config, output_path, c_logger, dashboard_logger) >>> trainer.fit() TODO: @@ -159,7 +155,7 @@ class Trainer: self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark) if config is None: # parse config from console arguments - config, output_path, _, c_logger, tb_logger, wandb_logger = process_args(args) + config, output_path, _, c_logger, dashboard_logger = process_args(args) self.output_path = output_path self.args = args @@ -167,26 +163,27 @@ class Trainer: # init loggers self.c_logger = ConsoleLogger() if c_logger is None else c_logger - if tb_logger is None: - self.tb_logger = TensorboardLogger(output_path, model_name=config.model) - self.tb_logger.tb_add_text("model-config", f"
{config.to_json()}
", 0) - else: - self.tb_logger = tb_logger + self.dashboard_logger = dashboard_logger - if wandb_logger is None: - wandb_project_name = config.model - if config.wandb_project_name: - wandb_project_name = config.wandb_project_name + if self.dashboard_logger is None: + if config.dashboard_logger == "tensorboard": + self.dashboard_logger = TensorboardLogger(output_path, model_name=config.model) + self.dashboard_logger.add_text("model-config", f"
{config.to_json()}
", 0) + + elif config.dashboard_logger == "wandb": + project_name = config.model + if config.project_name: + project_name = config.project_name + + self.dashboard_logger = WandbLogger( + project=project_name, + name=config.run_name, + config=config, + entity=config.wandb_entity, + ) + if not self.config.log_model_step: + self.config.log_model_step = self.config.save_step - wandb_logger = WandbLogger( - disabled=config.wandb_disabled, - project=wandb_project_name, - name=config.run_name, - config=config, - entity=config.wandb_entity, - ) - else: - self.wandb_logger = wandb_logger log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt") self._setup_logger_config(log_file) @@ -646,9 +643,8 @@ class Trainer: if self.args.rank == 0: # Plot Training Iter Stats # reduce TB load and don't log every step - if self.total_steps_done % self.config.tb_plot_step == 0: - self.tb_logger.tb_train_step_stats(self.total_steps_done, loss_dict) - self.wandb_logger.log(loss_dict, "train/", flush=True) + if self.total_steps_done % self.config.plot_step == 0: + self.dashboard_logger.train_step_stats(self.total_steps_done, loss_dict) if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0: if self.config.checkpoint: # checkpoint the model @@ -664,13 +660,10 @@ class Trainer: model_loss=target_avg_loss, ) - if ( - self.config.wandb_log_model_step - and self.total_steps_done % self.config.wandb_log_model_step == 0 - ): - # log checkpoint as W&B artifact + if self.total_steps_done % self.config.log_model_step == 0: + # log checkpoint as artifact aliases = [f"epoch-{self.epochs_done}", f"step-{self.total_steps_done}"] - self.wandb_logger.log_artifact(self.output_path, "checkpoint", "model", aliases) + self.dashboard_logger.log_artifact(self.output_path, "checkpoint", "model", aliases) # training visualizations figures, audios = None, None @@ -679,15 +672,13 @@ class Trainer: elif hasattr(self.model, "train_log"): figures, audios = self.model.train_log(self.ap, batch, outputs) if figures is not None: - self.tb_logger.tb_train_figures(self.total_steps_done, figures) - self.wandb_logger.log_figures(figures, "train/") + self.dashboard_logger.train_figures(self.total_steps_done, figures) if audios is not None: - self.tb_logger.tb_train_audios(self.total_steps_done, audios, self.ap.sample_rate) - self.wandb_logger.log_audios(audios, self.ap.sample_rate, "train/") + self.dashboard_logger.train_audios(self.total_steps_done, audios, self.ap.sample_rate) self.total_steps_done += 1 self.callbacks.on_train_step_end() - self.wandb_logger.flush() + self.dashboard_logger.flush() return outputs, loss_dict def train_epoch(self) -> None: @@ -712,10 +703,9 @@ class Trainer: if self.args.rank == 0: epoch_stats = {"epoch_time": epoch_time} epoch_stats.update(self.keep_avg_train.avg_values) - self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats) - self.wandb_logger.log_scalars(epoch_stats, "train/") - if self.config.tb_model_param_stats: - self.tb_logger.tb_model_weights(self.model, self.total_steps_done) + self.dashboard_logger.train_epoch_stats(self.total_steps_done, epoch_stats) + if self.config.model_param_stats: + self.logger.model_weights(self.model, self.total_steps_done) # scheduler step after the epoch if self.scheduler is not None and self.config.scheduler_after_epoch: if isinstance(self.scheduler, list): @@ -816,13 +806,10 @@ class Trainer: elif hasattr(self.model, "eval_log"): figures, audios = self.model.eval_log(self.ap, batch, outputs) if figures is not None: - self.tb_logger.tb_eval_figures(self.total_steps_done, figures) - self.wandb_logger.log_figures(figures, "eval/") + self.dashboard_logger.eval_figures(self.total_steps_done, figures) if audios is not None: - self.tb_logger.tb_eval_audios(self.total_steps_done, audios, self.ap.sample_rate) - self.wandb_logger.log_audios(audios, self.ap.sample_rate, "eval/") - self.tb_logger.tb_eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values) - self.wandb_logger.log_scalars(self.keep_avg_eval.avg_values, "eval/") + self.dashboard_logger.eval_audios(self.total_steps_done, audios, self.ap.sample_rate) + self.dashboard_logger.eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values) def test_run(self) -> None: """Run test and log the results. Test run must be defined by the model. @@ -839,11 +826,9 @@ class Trainer: samples = self.eval_loader.dataset.load_test_samples(1) figures, audios = self.model.test_run(self.ap, samples, None) else: - figures, audios = self.model.test_run(self.ap) - self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"]) - self.tb_logger.tb_test_figures(self.total_steps_done, figures) - self.wandb_logger.log_audios(audios, self.config.audio["sample_rate"], "test/") - self.wandb_logger.log_figures(figures, "test/") + figures, audios = self.model.test_run() + self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"]) + self.dashboard_logger.test_figures(self.total_steps_done, figures) def _fit(self) -> None: """🏃 train -> evaluate -> test for the number of epochs.""" @@ -875,13 +860,13 @@ class Trainer: """Where the ✨️magic✨️ happens...""" try: self._fit() - self.wandb_logger.finish() + self.dashboard_logger.finish() except KeyboardInterrupt: self.callbacks.on_keyboard_interrupt() # if the output folder is empty remove the run. remove_experiment_folder(self.output_path) # finish the wandb run and sync data - self.wandb_logger.finish() + self.dashboard_logger.finish() # stop without error signal try: sys.exit(0) @@ -1108,9 +1093,8 @@ def process_args(args, config=None): audio_path (str): Path to save generated test audios. c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does logging to the console. - tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does - the TensorBoard logging. - wandb_logger (TTS.utils.tensorboard.WandbLogger): Class that does the W&B Logging + + dashboard_logger (WandbLogger or TensorboardLogger): Class that does the dashboard Logging TODO: - Interactive config definition. @@ -1146,8 +1130,7 @@ def process_args(args, config=None): experiment_path = get_experiment_folder_path(config.output_path, config.run_name) audio_path = os.path.join(experiment_path, "test_audios") # setup rank 0 process in distributed training - tb_logger = None - wandb_logger = None + dashboard_logger = None if args.rank == 0: new_fields = {} if args.restore_path: @@ -1160,24 +1143,28 @@ def process_args(args, config=None): used_characters = parse_symbols() new_fields["characters"] = used_characters copy_model_files(config, experiment_path, new_fields) - tb_logger = TensorboardLogger(experiment_path, model_name=config.model) - # write model desc to tensorboard - tb_logger.tb_add_text("model-config", f"
{config.to_json()}
", 0) + os.chmod(audio_path, 0o775) + os.chmod(experiment_path, 0o775) - wandb_project_name = config.model - if config.wandb_project_name: - wandb_project_name = config.wandb_project_name + if config.dashboard_logger == "tensorboard": + dashboard_logger = TensorboardLogger(output_path, model_name=config.model) + dashboard_logger.add_text("model-config", f"
{config.to_json()}
", 0) + + elif config.dashboard_logger == "wandb": + project_name = config.model + if config.project_name: + project_name = config.project_name + + dashboard_logger = WandbLogger( + project=project_name, + name=config.run_name, + config=config, + entity=config.wandb_entity, + ) - wandb_logger = WandbLogger( - disabled=config.wandb_disabled, - project=wandb_project_name, - name=config.run_name, - config=config, - entity=config.wandb_entity, - ) c_logger = ConsoleLogger() - return config, experiment_path, audio_path, c_logger, tb_logger, wandb_logger + return config, experiment_path, audio_path, c_logger, dashboard_logger def init_arguments(): @@ -1193,5 +1180,5 @@ def init_training(argv: Union[List, Coqpit], config: Coqpit = None): else: parser = init_arguments() args = parser.parse_known_args() - config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger, wandb_logger = process_args(args, config) - return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger, wandb_logger + config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = process_args(args, config) + return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger diff --git a/TTS/utils/logging/tensorboard_logger.py b/TTS/utils/logging/tensorboard_logger.py index 3d7ea1e6..ec57fc89 100644 --- a/TTS/utils/logging/tensorboard_logger.py +++ b/TTS/utils/logging/tensorboard_logger.py @@ -10,7 +10,7 @@ class TensorboardLogger(object): self.train_stats = {} self.eval_stats = {} - def tb_model_weights(self, model, step): + def model_weights(self, model, step): layer_num = 1 for name, param in model.named_parameters(): if param.numel() == 1: @@ -41,32 +41,41 @@ class TensorboardLogger(object): except RuntimeError: traceback.print_exc() - def tb_train_step_stats(self, step, stats): + def train_step_stats(self, step, stats): self.dict_to_tb_scalar(f"{self.model_name}_TrainIterStats", stats, step) - def tb_train_epoch_stats(self, step, stats): + def train_epoch_stats(self, step, stats): self.dict_to_tb_scalar(f"{self.model_name}_TrainEpochStats", stats, step) - def tb_train_figures(self, step, figures): + def train_figures(self, step, figures): self.dict_to_tb_figure(f"{self.model_name}_TrainFigures", figures, step) - def tb_train_audios(self, step, audios, sample_rate): + def train_audios(self, step, audios, sample_rate): self.dict_to_tb_audios(f"{self.model_name}_TrainAudios", audios, step, sample_rate) - def tb_eval_stats(self, step, stats): + def eval_stats(self, step, stats): self.dict_to_tb_scalar(f"{self.model_name}_EvalStats", stats, step) - def tb_eval_figures(self, step, figures): + def eval_figures(self, step, figures): self.dict_to_tb_figure(f"{self.model_name}_EvalFigures", figures, step) - def tb_eval_audios(self, step, audios, sample_rate): + def eval_audios(self, step, audios, sample_rate): self.dict_to_tb_audios(f"{self.model_name}_EvalAudios", audios, step, sample_rate) - def tb_test_audios(self, step, audios, sample_rate): + def test_audios(self, step, audios, sample_rate): self.dict_to_tb_audios(f"{self.model_name}_TestAudios", audios, step, sample_rate) - def tb_test_figures(self, step, figures): + def test_figures(self, step, figures): self.dict_to_tb_figure(f"{self.model_name}_TestFigures", figures, step) - def tb_add_text(self, title, text, step): + def add_text(self, title, text, step): self.writer.add_text(title, text, step) + + def log_artifact(self, file_or_dir, name, artifact_type, aliases=None): + return + + def flush(self): + return + + def finish(self): + return diff --git a/TTS/utils/logging/wandb_logger.py b/TTS/utils/logging/wandb_logger.py index 6c09c2a3..45129a86 100644 --- a/TTS/utils/logging/wandb_logger.py +++ b/TTS/utils/logging/wandb_logger.py @@ -1,4 +1,5 @@ from pathlib import Path +import traceback try: import wandb @@ -8,40 +9,85 @@ except ImportError: class WandbLogger: - def __init__(self, disabled=False, **kwargs): + def __init__(self, **kwargs): + + if not wandb: + raise Exception("install wandb using `pip install wandb` to use WandbLogger") + self.run = None - if wandb and not disabled: - self.run = wandb.init(**kwargs) if not wandb.run else wandb.run + self.run = wandb.init(**kwargs) if not wandb.run else wandb.run + self.model_name = self.run.config.model self.log_dict = {} + def model_weights(self, model): + layer_num = 1 + for name, param in model.named_parameters(): + if param.numel() == 1: + self.dict_to_scalar("weights",{"layer{}-{}/value".format(layer_num, name): param.max()}) + else: + self.dict_to_scalar("weights", {"layer{}-{}/max".format(layer_num, name): param.max()}) + self.dict_to_scalar("weights", {"layer{}-{}/min".format(layer_num, name): param.min()}) + self.dict_to_scalar("weights", {"layer{}-{}/mean".format(layer_num, name): param.mean()}) + self.dict_to_scalar("weights", {"layer{}-{}/std".format(layer_num, name): param.std()}) + ''' + self.writer.add_histogram("layer{}-{}/param".format(layer_num, name), param, step) + self.writer.add_histogram("layer{}-{}/grad".format(layer_num, name), param.grad, step) + ''' + layer_num += 1 + + def dict_to_scalar(self, scope_name, stats): + for key, value in stats.items(): + self.log_dict["{}/{}".format(scope_name, key)] = value + + def dict_to_figure(self, scope_name, figures): + for key, value in figures.items(): + self.log_dict["{}/{}".format(scope_name, key)] = wandb.Image(value) + + def dict_to_audios(self, scope_name, audios, sample_rate): + for key, value in audios.items(): + if value.dtype == "float16": + value = value.astype("float32") + try: + self.log_dict["{}/{}".format(scope_name, key)] = wandb.Audio(value, sample_rate=sample_rate) + except RuntimeError: + traceback.print_exc() + + def log(self, log_dict, prefix="", flush=False): for key, value in log_dict.items(): self.log_dict[prefix + key] = value if flush: # for cases where you don't want to accumulate data self.flush() - def log_scalars(self, log_dict, prefix=""): - if not self.run: - return + def train_step_stats(self, step, stats): + self.dict_to_scalar(f"{self.model_name}_TrainIterStats", stats) - for key, value in log_dict.items(): - self.log_dict[prefix + key] = value + def train_epoch_stats(self, step, stats): + self.dict_to_scalar(f"{self.model_name}_TrainEpochStats", stats) - def log_audios(self, log_dict, sample_rate, prefix=""): - if not self.run: - return + def train_figures(self, step, figures): + self.dict_to_figure(f"{self.model_name}_TrainFigures", figures) - prefix = "audios/" + prefix - for key, value in log_dict.items(): - self.log_dict[prefix + key] = wandb.Audio(value, sample_rate=int(sample_rate)) + def train_audios(self, step, audios, sample_rate): + self.dict_to_audios(f"{self.model_name}_TrainAudios", audios, sample_rate) - def log_figures(self, log_dict, prefix=""): - if not self.run: - return + def eval_stats(self, step, stats): + self.dict_to_scalar(f"{self.model_name}_EvalStats", stats) - prefix = "figures/" + prefix - for key, value in log_dict.items(): - self.log_dict[prefix + key] = wandb.Image(value) + def eval_figures(self, step, figures): + self.dict_to_figure(f"{self.model_name}_EvalFigures", figures) + + def eval_audios(self, step, audios, sample_rate): + self.dict_to_audios(f"{self.model_name}_EvalAudios", audios, sample_rate) + + def test_audios(self, step, audios, sample_rate): + self.dict_to_audios(f"{self.model_name}_TestAudios", audios, sample_rate) + + def test_figures(self, step, figures): + self.dict_to_figure(f"{self.model_name}_TestFigures", figures) + + def add_text(self, title, text, step): + self.log_dict[title] = wandb.HTML(f"

{text}

") def flush(self): if self.run: