From f5e50ad502c9abd651a938fce3061f29f2c50639 Mon Sep 17 00:00:00 2001
From: Ayush Chaurasia
Date: Mon, 5 Jul 2021 06:42:36 +0000
Subject: [PATCH] WandbLogger

---
 TTS/trainer.py                                | 50 ++++++++++++--
 TTS/utils/logging/__init__.py                 |  1 +
 TTS/utils/logging/wandb_logger.py             | 69 +++++++++++++++++++
 recipes/ljspeech/align_tts/train_aligntts.py  |  4 +-
 recipes/ljspeech/glow_tts/train_glowtts.py    |  4 +-
 recipes/ljspeech/hifigan/train_hifigan.py     |  4 +-
 .../train_multiband_melgan.py                 |  4 +-
 recipes/ljspeech/univnet/train.py             |  4 +-
 recipes/ljspeech/wavegrad/train_wavegrad.py   |  4 +-
 recipes/ljspeech/wavernn/train_wavernn.py     |  4 +-
 10 files changed, 128 insertions(+), 20 deletions(-)
 create mode 100644 TTS/utils/logging/wandb_logger.py

diff --git a/TTS/trainer.py b/TTS/trainer.py
index a3e87e67..d39fd747 100644
--- a/TTS/trainer.py
+++ b/TTS/trainer.py
@@ -38,7 +38,7 @@ from TTS.utils.generic_utils import (
     to_cuda,
 )
 from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
-from TTS.utils.logging import ConsoleLogger, TensorboardLogger
+from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
 from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
 from TTS.vocoder.models import setup_model as setup_vocoder_model
@@ -52,6 +52,7 @@ if platform.system() != "Windows":
     rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
     resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
 
+
 if is_apex_available():
     from apex import amp
 
@@ -92,6 +93,7 @@ class Trainer:
         output_path: str,
         c_logger: ConsoleLogger = None,
         tb_logger: TensorboardLogger = None,
+        wandb_logger: WandbLogger = None,
         model: nn.Module = None,
         cudnn_benchmark: bool = False,
     ) -> None:
@@ -119,6 +121,9 @@ class Trainer:
             tb_logger (TensorboardLogger, optional): Tensorboard logger. If not provided, the default logger is used.
                 Defaults to None.
 
+            wandb_logger (WandbLogger, optional): W&B logger. If not provided, a default logger is created.
+                Defaults to None.
+
             model (nn.Module, optional): Initialized and ready-to-train model. If it is not defined, `Trainer`
                 initializes a model from the provided config. Defaults to None.
 
@@ -167,6 +172,16 @@ class Trainer:
             self.tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
         else:
             self.tb_logger = tb_logger
+
+        if wandb_logger is None:
+            self.wandb_logger = WandbLogger(
+                project=os.path.basename(output_path),
+                name=config.model,
+                config=config,
+            )
+        else:
+            self.wandb_logger = wandb_logger
+
         log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
         self._setup_logger_config(log_file)
 
@@ -627,6 +642,7 @@ class Trainer:
             # reduce TB load and don't log every step
             if self.total_steps_done % self.config.tb_plot_step == 0:
                 self.tb_logger.tb_train_step_stats(self.total_steps_done, loss_dict)
+                self.wandb_logger.log(loss_dict, "train/", flush=True)
             if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0:
                 if self.config.checkpoint:
                     # checkpoint the model
@@ -649,10 +665,14 @@ class Trainer:
                 figures, audios = self.model.train_log(self.ap, batch, outputs)
                 if figures is not None:
                     self.tb_logger.tb_train_figures(self.total_steps_done, figures)
+                    self.wandb_logger.log_figures(figures, "train/")
                 if audios is not None:
                     self.tb_logger.tb_train_audios(self.total_steps_done, audios, self.ap.sample_rate)
+                    self.wandb_logger.log_audios(audios, self.ap.sample_rate, "train/")
+
         self.total_steps_done += 1
         self.callbacks.on_train_step_end()
+        self.wandb_logger.flush()
         return outputs, loss_dict
 
     def train_epoch(self) -> None:
@@ -678,6 +698,7 @@ class Trainer:
             epoch_stats = {"epoch_time": epoch_time}
             epoch_stats.update(self.keep_avg_train.avg_values)
             self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
+            self.wandb_logger.log_scalars(epoch_stats, "train/")
             if self.config.tb_model_param_stats:
                 self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
         # scheduler step after the epoch
@@ -781,9 +802,12 @@ class Trainer:
                 figures, audios = self.model.eval_log(self.ap, batch, outputs)
                 if figures is not None:
                     self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
+                    self.wandb_logger.log_figures(figures, "eval/")
                 if audios is not None:
                     self.tb_logger.tb_eval_audios(self.total_steps_done, audios, self.ap.sample_rate)
+                    self.wandb_logger.log_audios(audios, self.ap.sample_rate, "eval/")
         self.tb_logger.tb_eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values)
+        self.wandb_logger.log_scalars(self.keep_avg_eval.avg_values, "eval/")
 
     def test_run(self) -> None:
         """Run test and log the results. Test run must be defined by the model.
@@ -803,6 +827,8 @@ class Trainer:
             figures, audios = self.model.test_run(self.ap)
             self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
             self.tb_logger.tb_test_figures(self.total_steps_done, figures)
+            self.wandb_logger.log_audios(audios, self.config.audio["sample_rate"], "test/")
+            self.wandb_logger.log_figures(figures, "test/")
 
     def _fit(self) -> None:
         """🏃 train -> evaluate -> test for the number of epochs."""
@@ -1066,6 +1092,7 @@ def process_args(args, config=None):
             logging to the console.
         tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does
             the TensorBoard logging.
+        wandb_logger (TTS.utils.logging.WandbLogger): Class that does the W&B logging.
 
     TODO:
         - Interactive config definition.
@@ -1079,14 +1106,14 @@ def process_args(args, config=None):
         args.restore_path, best_model = get_last_checkpoint(args.continue_path)
         if not args.best_path:
             args.best_path = best_model
     # init config if not already defined
     if config is None:
         if args.config_path:
             # init from a file
             config = load_config(args.config_path)
         else:
             # init from console args
-            from TTS.config.shared_configs import BaseTrainingConfig  # pylint: disable=import-outside-toplevel
+            from TTS.config.shared_configs import BaseTrainingConfig
 
             config_base = BaseTrainingConfig()
             config_base.parse_known_args(coqpit_overrides)
@@ -1101,6 +1128,7 @@ def process_args(args, config=None):
     audio_path = os.path.join(experiment_path, "test_audios")
     # setup rank 0 process in distributed training
     tb_logger = None
+    wandb_logger = None
     if args.rank == 0:
         new_fields = {}
         if args.restore_path:
@@ -1116,8 +1144,15 @@ def process_args(args, config=None):
         tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
         # write model desc to tensorboard
         tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
+
+        wandb_logger = WandbLogger(
+            project=config.model,
+            name=os.path.basename(experiment_path),
+            config=config,
+        )
+
     c_logger = ConsoleLogger()
-    return config, experiment_path, audio_path, c_logger, tb_logger
+    return config, experiment_path, audio_path, c_logger, tb_logger, wandb_logger
 
 
 def init_arguments():
@@ -1133,5 +1168,5 @@ def init_training(argv: Union[List, Coqpit], config: Coqpit = None):
     else:
         parser = init_arguments()
     args = parser.parse_known_args()
-    config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, config)
-    return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger
+    config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger, wandb_logger = process_args(args, config)
+    return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger, wandb_logger
diff --git a/TTS/utils/logging/__init__.py b/TTS/utils/logging/__init__.py
index 877131c4..a39bb912 100644
--- a/TTS/utils/logging/__init__.py
+++ b/TTS/utils/logging/__init__.py
@@ -1,2 +1,3 @@
 from TTS.utils.logging.console_logger import ConsoleLogger
 from TTS.utils.logging.tensorboard_logger import TensorboardLogger
+from TTS.utils.logging.wandb_logger import WandbLogger
diff --git a/TTS/utils/logging/wandb_logger.py b/TTS/utils/logging/wandb_logger.py
new file mode 100644
index 00000000..4d6f0c30
--- /dev/null
+++ b/TTS/utils/logging/wandb_logger.py
@@ -0,0 +1,69 @@
+from pathlib import Path
+
+try:
+    import wandb
+except ImportError:
+    wandb = None
+
+
+class WandbLogger:
+    def __init__(self, **kwargs):
+        self.run = None
+        if wandb:
+            self.run = wandb.init(**kwargs) if not wandb.run else wandb.run
+        self.log_dict = {}
+
+    def log(self, log_dict, prefix="", flush=False):
+        """Accumulate data in ``self.log_dict``; if ``flush`` is set,
+        the accumulated metrics are logged to the W&B dashboard immediately.
+        """
+        for key, value in log_dict.items():
+            self.log_dict[prefix + key] = value
+        if flush:  # for cases where you don't want to accumulate data
+            self.flush()
+
+    def log_scalars(self, log_dict, prefix=""):
+        if not self.run:
+            return
+
+        for key, value in log_dict.items():
+            self.log_dict[prefix + key] = value
+
+    def log_audios(self, log_dict, sample_rate, prefix=""):
+        if not self.run:
+            return
+
+        prefix = "audios/" + prefix
+        for key, value in log_dict.items():
+            self.log_dict[prefix + key] = wandb.Audio(value, sample_rate=int(sample_rate))
+
+    def log_figures(self, log_dict, prefix=""):
+        if not self.run:
+            return
+
+        prefix = "figures/" + prefix
+        for key, value in log_dict.items():
+            self.log_dict[prefix + key] = wandb.Image(value)
+
+    def flush(self):
+        if self.run:
+            wandb.log(self.log_dict)
+        self.log_dict = {}
+
+    def finish(self):
+        """Finish the current W&B run."""
+        if self.run:
+            self.run.finish()
+
+    def log_artifact(self, file_or_dir, name, type, aliases=None):
+        if not self.run:
+            return
+
+        artifact = wandb.Artifact(name, type=type)
+        data_path = Path(file_or_dir)
+        if data_path.is_dir():
+            artifact.add_dir(str(data_path))
+        elif data_path.is_file():
+            artifact.add_file(str(data_path))
+
+        self.run.log_artifact(artifact, aliases=aliases)
diff --git a/recipes/ljspeech/align_tts/train_aligntts.py b/recipes/ljspeech/align_tts/train_aligntts.py
index 85c22673..4ef215fc 100644
--- a/recipes/ljspeech/align_tts/train_aligntts.py
+++ b/recipes/ljspeech/align_tts/train_aligntts.py
@@ -25,6 +25,6 @@ config = AlignTTSConfig(
     output_path=output_path,
     datasets=[dataset_config],
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
 trainer.fit()
diff --git a/recipes/ljspeech/glow_tts/train_glowtts.py b/recipes/ljspeech/glow_tts/train_glowtts.py
index f77997e8..648bac75 100644
--- a/recipes/ljspeech/glow_tts/train_glowtts.py
+++ b/recipes/ljspeech/glow_tts/train_glowtts.py
@@ -25,6 +25,6 @@ config = GlowTTSConfig(
     output_path=output_path,
     datasets=[dataset_config],
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
 trainer.fit()
diff --git a/recipes/ljspeech/hifigan/train_hifigan.py b/recipes/ljspeech/hifigan/train_hifigan.py
index af615ace..a90bc52a 100644
--- a/recipes/ljspeech/hifigan/train_hifigan.py
+++ b/recipes/ljspeech/hifigan/train_hifigan.py
@@ -24,6 +24,6 @@ config = HifiganConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
 trainer.fit()
diff --git a/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py b/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py
index 9d1b9e6f..e3cd2244 100644
--- a/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py
+++ b/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py
@@ -24,6 +24,6 @@ config = MultibandMelganConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
 trainer.fit()
diff --git a/recipes/ljspeech/univnet/train.py b/recipes/ljspeech/univnet/train.py
index d8f33ae3..163e1bba 100644
--- a/recipes/ljspeech/univnet/train.py
+++ b/recipes/ljspeech/univnet/train.py
@@ -24,6 +24,6 @@ config = UnivnetConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
 trainer.fit()
diff --git a/recipes/ljspeech/wavegrad/train_wavegrad.py b/recipes/ljspeech/wavegrad/train_wavegrad.py
index 4f82f50b..2939613e 100644
--- a/recipes/ljspeech/wavegrad/train_wavegrad.py
+++ b/recipes/ljspeech/wavegrad/train_wavegrad.py
@@ -22,6 +22,6 @@ config = WavegradConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
 trainer.fit()
diff --git a/recipes/ljspeech/wavernn/train_wavernn.py b/recipes/ljspeech/wavernn/train_wavernn.py
index 7222f78a..04353d79 100644
--- a/recipes/ljspeech/wavernn/train_wavernn.py
+++ b/recipes/ljspeech/wavernn/train_wavernn.py
@@ -24,6 +24,6 @@ config = WavernnConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=True)
+args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger, cudnn_benchmark=True)
 trainer.fit()
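
A minimal usage sketch of the WandbLogger added above (not part of the patch; it assumes wandb is installed, and the project name, metric names, and checkpoint path are illustrative only):

    from TTS.utils.logging import WandbLogger

    # Keyword arguments are forwarded to wandb.init(); the values here are made up.
    logger = WandbLogger(project="ljspeech-glow-tts", name="demo-run", config={"lr": 1e-3})

    # log() and log_scalars() only accumulate entries in logger.log_dict;
    # nothing is sent until flush(), so scalars, figures, and audios recorded
    # for the same training step land in a single wandb.log() call.
    logger.log({"loss": 0.42, "grad_norm": 1.7}, "train/")
    logger.log_scalars({"epoch_time": 123.4}, "train/")
    logger.flush()

    # Version a checkpoint as a W&B artifact (path and alias are hypothetical).
    logger.log_artifact("output/demo-run/best_model.pth", "best_model", "checkpoint", aliases=["latest"])
    logger.finish()

The accumulate-then-flush design mirrors how the Trainer hooks use the logger: per-step loss dicts are flushed immediately, while figures and audios queued during a step are sent together by the flush() call at the end of train_step().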