From 936a47504d87dfa308daef4ba4134ac836abce5a Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Fri, 9 Jul 2021 17:43:15 +0530 Subject: [PATCH] Update Logger API, recipes --- TTS/bin/train_encoder.py | 4 ++-- TTS/trainer.py | 4 +--- TTS/utils/logging/logger_base.py | 0 TTS/utils/logging/tensorboard_logger.py | 14 ++++++-------- TTS/utils/logging/wandb_logger.py | 13 ++++++------- recipes/ljspeech/align_tts/train_aligntts.py | 4 ++-- recipes/ljspeech/glow_tts/train_glowtts.py | 4 ++-- recipes/ljspeech/hifigan/train_hifigan.py | 4 ++-- .../multiband_melgan/train_multiband_melgan.py | 4 ++-- recipes/ljspeech/univnet/train.py | 4 ++-- recipes/ljspeech/wavegrad/train_wavegrad.py | 4 ++-- recipes/ljspeech/wavernn/train_wavernn.py | 4 ++-- 12 files changed, 29 insertions(+), 34 deletions(-) create mode 100644 TTS/utils/logging/logger_base.py diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index d419d50e..7ff35486 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -116,12 +116,12 @@ def train(model, optimizer, scheduler, criterion, data_loader, global_step): "step_time": step_time, "avg_loader_time": avg_loader_time, } - tb_logger.tb_train_epoch_stats(global_step, train_stats) + dashboard_logger.train_epoch_stats(global_step, train_stats) figures = { # FIXME: not constant "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), 10), } - tb_logger.tb_train_figures(global_step, figures) + dashboard_logger.train_figures(global_step, figures) if global_step % c.print_step == 0: print( diff --git a/TTS/trainer.py b/TTS/trainer.py index 8edb75cf..f0a2b18e 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -184,7 +184,6 @@ class Trainer: if not self.config.log_model_step: self.config.log_model_step = self.config.save_step - log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt") self._setup_logger_config(log_file) @@ -1147,7 +1146,7 @@ def process_args(args, config=None): os.chmod(experiment_path, 0o775) if config.dashboard_logger == "tensorboard": - dashboard_logger = TensorboardLogger(output_path, model_name=config.model) + dashboard_logger = TensorboardLogger(config.output_path, model_name=config.model) dashboard_logger.add_text("model-config", f"
{config.to_json()}
", 0) elif config.dashboard_logger == "wandb": @@ -1162,7 +1161,6 @@ def process_args(args, config=None): entity=config.wandb_entity, ) - c_logger = ConsoleLogger() return config, experiment_path, audio_path, c_logger, dashboard_logger diff --git a/TTS/utils/logging/logger_base.py b/TTS/utils/logging/logger_base.py new file mode 100644 index 00000000..e69de29b diff --git a/TTS/utils/logging/tensorboard_logger.py b/TTS/utils/logging/tensorboard_logger.py index ec57fc89..f5197edd 100644 --- a/TTS/utils/logging/tensorboard_logger.py +++ b/TTS/utils/logging/tensorboard_logger.py @@ -7,8 +7,6 @@ class TensorboardLogger(object): def __init__(self, log_dir, model_name): self.model_name = model_name self.writer = SummaryWriter(log_dir) - self.train_stats = {} - self.eval_stats = {} def model_weights(self, model, step): layer_num = 1 @@ -71,11 +69,11 @@ class TensorboardLogger(object): def add_text(self, title, text, step): self.writer.add_text(title, text, step) - def log_artifact(self, file_or_dir, name, artifact_type, aliases=None): - return - + def log_artifact(self, file_or_dir, name, artifact_type, aliases=None): # pylint: disable=W0613, R0201 + yield + def flush(self): - return - + self.writer.flush() + def finish(self): - return + self.writer.close() diff --git a/TTS/utils/logging/wandb_logger.py b/TTS/utils/logging/wandb_logger.py index 45129a86..f2fb6e1d 100644 --- a/TTS/utils/logging/wandb_logger.py +++ b/TTS/utils/logging/wandb_logger.py @@ -1,5 +1,7 @@ -from pathlib import Path +# pylint: disable=W0613 + import traceback +from pathlib import Path try: import wandb @@ -23,16 +25,14 @@ class WandbLogger: layer_num = 1 for name, param in model.named_parameters(): if param.numel() == 1: - self.dict_to_scalar("weights",{"layer{}-{}/value".format(layer_num, name): param.max()}) + self.dict_to_scalar("weights", {"layer{}-{}/value".format(layer_num, name): param.max()}) else: self.dict_to_scalar("weights", {"layer{}-{}/max".format(layer_num, name): param.max()}) self.dict_to_scalar("weights", {"layer{}-{}/min".format(layer_num, name): param.min()}) self.dict_to_scalar("weights", {"layer{}-{}/mean".format(layer_num, name): param.mean()}) self.dict_to_scalar("weights", {"layer{}-{}/std".format(layer_num, name): param.std()}) - ''' - self.writer.add_histogram("layer{}-{}/param".format(layer_num, name), param, step) - self.writer.add_histogram("layer{}-{}/grad".format(layer_num, name), param.grad, step) - ''' + self.log_dict["weights/layer{}-{}/param".format(layer_num, name)] = wandb.Histogram(param) + self.log_dict["weights/layer{}-{}/grad".format(layer_num, name)] = wandb.Histogram(param.grad) layer_num += 1 def dict_to_scalar(self, scope_name, stats): @@ -52,7 +52,6 @@ class WandbLogger: except RuntimeError: traceback.print_exc() - def log(self, log_dict, prefix="", flush=False): for key, value in log_dict.items(): self.log_dict[prefix + key] = value diff --git a/recipes/ljspeech/align_tts/train_aligntts.py b/recipes/ljspeech/align_tts/train_aligntts.py index 4ef215fc..4e214f92 100644 --- a/recipes/ljspeech/align_tts/train_aligntts.py +++ b/recipes/ljspeech/align_tts/train_aligntts.py @@ -25,6 +25,6 @@ config = AlignTTSConfig( output_path=output_path, datasets=[dataset_config], ) -args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config) -trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger) +args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config) +trainer = Trainer(args, config, output_path, c_logger, dashboard_logger) trainer.fit() diff --git a/recipes/ljspeech/glow_tts/train_glowtts.py b/recipes/ljspeech/glow_tts/train_glowtts.py index 648bac75..fbd54e88 100644 --- a/recipes/ljspeech/glow_tts/train_glowtts.py +++ b/recipes/ljspeech/glow_tts/train_glowtts.py @@ -25,6 +25,6 @@ config = GlowTTSConfig( output_path=output_path, datasets=[dataset_config], ) -args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config) -trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger) +args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config) +trainer = Trainer(args, config, output_path, c_logger, dashboard_logger) trainer.fit() diff --git a/recipes/ljspeech/hifigan/train_hifigan.py b/recipes/ljspeech/hifigan/train_hifigan.py index a90bc52a..f50ef476 100644 --- a/recipes/ljspeech/hifigan/train_hifigan.py +++ b/recipes/ljspeech/hifigan/train_hifigan.py @@ -24,6 +24,6 @@ config = HifiganConfig( data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"), output_path=output_path, ) -args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config) -trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger) +args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config) +trainer = Trainer(args, config, output_path, c_logger, dashboard_logger) trainer.fit() diff --git a/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py b/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py index e3cd2244..1473ec3c 100644 --- a/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py +++ b/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py @@ -24,6 +24,6 @@ config = MultibandMelganConfig( data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"), output_path=output_path, ) -args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config) -trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger) +args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config) +trainer = Trainer(args, config, output_path, c_logger, dashboard_logger) trainer.fit() diff --git a/recipes/ljspeech/univnet/train.py b/recipes/ljspeech/univnet/train.py index 163e1bba..e8979c92 100644 --- a/recipes/ljspeech/univnet/train.py +++ b/recipes/ljspeech/univnet/train.py @@ -24,6 +24,6 @@ config = UnivnetConfig( data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"), output_path=output_path, ) -args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config) -trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger) +args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config) +trainer = Trainer(args, config, output_path, c_logger, dashboard_logger) trainer.fit() diff --git a/recipes/ljspeech/wavegrad/train_wavegrad.py b/recipes/ljspeech/wavegrad/train_wavegrad.py index 2939613e..fe038915 100644 --- a/recipes/ljspeech/wavegrad/train_wavegrad.py +++ b/recipes/ljspeech/wavegrad/train_wavegrad.py @@ -22,6 +22,6 @@ config = WavegradConfig( data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"), output_path=output_path, ) -args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config) -trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger) +args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config) +trainer = Trainer(args, config, output_path, c_logger, dashboard_logger) trainer.fit() diff --git a/recipes/ljspeech/wavernn/train_wavernn.py b/recipes/ljspeech/wavernn/train_wavernn.py index 04353d79..8f138298 100644 --- a/recipes/ljspeech/wavernn/train_wavernn.py +++ b/recipes/ljspeech/wavernn/train_wavernn.py @@ -24,6 +24,6 @@ config = WavernnConfig( data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"), output_path=output_path, ) -args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config) -trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger, cudnn_benchmark=True) +args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config) +trainer = Trainer(args, config, output_path, c_logger, dashboard_logger, cudnn_benchmark=True) trainer.fit()