WandbLogger

This commit is contained in:
Ayush Chaurasia 2021-07-05 06:42:36 +00:00 committed by Eren Gölge
parent 06018251e6
commit f5e50ad502
10 changed files with 128 additions and 20 deletions

View File

@ -38,7 +38,7 @@ from TTS.utils.generic_utils import (
to_cuda, to_cuda,
) )
from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
from TTS.utils.logging import ConsoleLogger, TensorboardLogger from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger
from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.models import setup_model as setup_vocoder_model from TTS.vocoder.models import setup_model as setup_vocoder_model
@ -52,6 +52,7 @@ if platform.system() != "Windows":
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
if is_apex_available(): if is_apex_available():
from apex import amp from apex import amp
@ -92,6 +93,7 @@ class Trainer:
output_path: str, output_path: str,
c_logger: ConsoleLogger = None, c_logger: ConsoleLogger = None,
tb_logger: TensorboardLogger = None, tb_logger: TensorboardLogger = None,
wandb_logger: WandbLogger = None,
model: nn.Module = None, model: nn.Module = None,
cudnn_benchmark: bool = False, cudnn_benchmark: bool = False,
) -> None: ) -> None:
@ -119,6 +121,9 @@ class Trainer:
tb_logger (TensorboardLogger, optional): Tensorboard logger. If not provided, the default logger is used. tb_logger (TensorboardLogger, optional): Tensorboard logger. If not provided, the default logger is used.
Defaults to None. Defaults to None.
wandb_logger (WandbLogger, optional): W&B logger. If not provided, the default logger is used.
Defaults to None.
model (nn.Module, optional): Initialized and ready-to-train model. If it is not defined, `Trainer` model (nn.Module, optional): Initialized and ready-to-train model. If it is not defined, `Trainer`
initializes a model from the provided config. Defaults to None. initializes a model from the provided config. Defaults to None.
@ -167,6 +172,16 @@ class Trainer:
self.tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0) self.tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
else: else:
self.tb_logger = tb_logger self.tb_logger = tb_logger
if wandb_logger is None:
self.wandb_logger = WandbLogger(
project=os.path(output_path).stem,
name=config.model,
config=config,
)
else:
self.wandb_logger = wandb_logger
log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt") log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
self._setup_logger_config(log_file) self._setup_logger_config(log_file)
@ -627,6 +642,7 @@ class Trainer:
# reduce TB load and don't log every step # reduce TB load and don't log every step
if self.total_steps_done % self.config.tb_plot_step == 0: if self.total_steps_done % self.config.tb_plot_step == 0:
self.tb_logger.tb_train_step_stats(self.total_steps_done, loss_dict) self.tb_logger.tb_train_step_stats(self.total_steps_done, loss_dict)
self.wandb_logger.log(loss_dict, "train/", flush=True)
if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0: if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0:
if self.config.checkpoint: if self.config.checkpoint:
# checkpoint the model # checkpoint the model
@ -649,10 +665,14 @@ class Trainer:
figures, audios = self.model.train_log(self.ap, batch, outputs) figures, audios = self.model.train_log(self.ap, batch, outputs)
if figures is not None: if figures is not None:
self.tb_logger.tb_train_figures(self.total_steps_done, figures) self.tb_logger.tb_train_figures(self.total_steps_done, figures)
self.wandb_logger.log_figures(figures, "train/")
if audios is not None: if audios is not None:
self.tb_logger.tb_train_audios(self.total_steps_done, audios, self.ap.sample_rate) self.tb_logger.tb_train_audios(self.total_steps_done, audios, self.ap.sample_rate)
self.wandb_logger.log_audios(audios, self.ap.sample_rate, "train/")
self.total_steps_done += 1 self.total_steps_done += 1
self.callbacks.on_train_step_end() self.callbacks.on_train_step_end()
self.wandb_logger.flush()
return outputs, loss_dict return outputs, loss_dict
def train_epoch(self) -> None: def train_epoch(self) -> None:
@ -678,6 +698,7 @@ class Trainer:
epoch_stats = {"epoch_time": epoch_time} epoch_stats = {"epoch_time": epoch_time}
epoch_stats.update(self.keep_avg_train.avg_values) epoch_stats.update(self.keep_avg_train.avg_values)
self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats) self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
self.wandb_logger.log_scalars(epoch_stats, "train/")
if self.config.tb_model_param_stats: if self.config.tb_model_param_stats:
self.tb_logger.tb_model_weights(self.model, self.total_steps_done) self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
# scheduler step after the epoch # scheduler step after the epoch
@ -781,9 +802,12 @@ class Trainer:
figures, audios = self.model.eval_log(self.ap, batch, outputs) figures, audios = self.model.eval_log(self.ap, batch, outputs)
if figures is not None: if figures is not None:
self.tb_logger.tb_eval_figures(self.total_steps_done, figures) self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
self.wandb_logger.log_figures(figures, "eval/")
if audios is not None: if audios is not None:
self.tb_logger.tb_eval_audios(self.total_steps_done, audios, self.ap.sample_rate) self.tb_logger.tb_eval_audios(self.total_steps_done, audios, self.ap.sample_rate)
self.wandb_logger.log_audios(audios, self.ap.sample_rate, "eval/")
self.tb_logger.tb_eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values) self.tb_logger.tb_eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values)
self.wandb_logger.log_scalars(self.keep_avg_eval.avg_values, "eval/")
def test_run(self) -> None: def test_run(self) -> None:
"""Run test and log the results. Test run must be defined by the model. """Run test and log the results. Test run must be defined by the model.
@ -803,6 +827,8 @@ class Trainer:
figures, audios = self.model.test_run(self.ap) figures, audios = self.model.test_run(self.ap)
self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"]) self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
self.tb_logger.tb_test_figures(self.total_steps_done, figures) self.tb_logger.tb_test_figures(self.total_steps_done, figures)
self.wandb_logger.log_audios(audios, self.config.audio["sample_rate"], "test/")
self.wandb_logger.log_figures(figures, "test/")
def _fit(self) -> None: def _fit(self) -> None:
"""🏃 train -> evaluate -> test for the number of epochs.""" """🏃 train -> evaluate -> test for the number of epochs."""
@ -1066,6 +1092,7 @@ def process_args(args, config=None):
logging to the console. logging to the console.
tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does
the TensorBoard logging. the TensorBoard logging.
wandb_logger (TTS.utils.tensorboard.WandbLogger): Class that does the W&B Loggin
TODO: TODO:
- Interactive config definition. - Interactive config definition.
@ -1079,14 +1106,17 @@ def process_args(args, config=None):
args.restore_path, best_model = get_last_checkpoint(args.continue_path) args.restore_path, best_model = get_last_checkpoint(args.continue_path)
if not args.best_path: if not args.best_path:
args.best_path = best_model args.best_path = best_model
# init config if not already defined # setup output paths and read configs
if config is None:
config = load_config(args.config_path)
# init config
if config is None: if config is None:
if args.config_path: if args.config_path:
# init from a file # init from a file
config = load_config(args.config_path) config = load_config(args.config_path)
else: else:
# init from console args # init from console args
from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel from TTS.config.shared_configs import BaseTrainingConfig
config_base = BaseTrainingConfig() config_base = BaseTrainingConfig()
config_base.parse_known_args(coqpit_overrides) config_base.parse_known_args(coqpit_overrides)
@ -1101,6 +1131,7 @@ def process_args(args, config=None):
audio_path = os.path.join(experiment_path, "test_audios") audio_path = os.path.join(experiment_path, "test_audios")
# setup rank 0 process in distributed training # setup rank 0 process in distributed training
tb_logger = None tb_logger = None
wandb_logger = None
if args.rank == 0: if args.rank == 0:
new_fields = {} new_fields = {}
if args.restore_path: if args.restore_path:
@ -1116,8 +1147,15 @@ def process_args(args, config=None):
tb_logger = TensorboardLogger(experiment_path, model_name=config.model) tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
# write model desc to tensorboard # write model desc to tensorboard
tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0) tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
wandb_logger = WandbLogger(
project=config.model,
name=os.path.basename(experiment_path),
config=config,
)
c_logger = ConsoleLogger() c_logger = ConsoleLogger()
return config, experiment_path, audio_path, c_logger, tb_logger return config, experiment_path, audio_path, c_logger, tb_logger, wandb_logger
def init_arguments(): def init_arguments():
@ -1133,5 +1171,5 @@ def init_training(argv: Union[List, Coqpit], config: Coqpit = None):
else: else:
parser = init_arguments() parser = init_arguments()
args = parser.parse_known_args() args = parser.parse_known_args()
config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, config) config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger, wandb_logger = process_args(args, config)
return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger, wandb_logger

View File

@ -1,2 +1,3 @@
from TTS.utils.logging.console_logger import ConsoleLogger from TTS.utils.logging.console_logger import ConsoleLogger
from TTS.utils.logging.tensorboard_logger import TensorboardLogger from TTS.utils.logging.tensorboard_logger import TensorboardLogger
from TTS.utils.logging.wandb_logger import WandbLogger

View File

@ -0,0 +1,69 @@
from pathlib import Path
try:
import wandb
from wandb import init, finish
except ImportError:
wandb = None
class WandbLogger:
def __init__(self, **kwargs):
self.run = None
if wandb:
self.run = wandb.init(**kwargs) if not wandb.run else wandb.run
self.log_dict = {}
def log(self, log_dict, prefix="", flush=False):
"""
This function accumulates data in self.log_dict. If flush is set.
the accumulated metrics will be logged directly to wandb dashboard.
"""
for key, value in log_dict.items():
self.log_dict[prefix + key] = value
if flush: # for cases where you don't want to accumulate data
self.flush()
def log_scalars(self, log_dict, prefix=""):
if not self.run:
return
for key, value in log_dict.items():
self.log_dict[prefix + key] = value
def log_audios(self, log_dict, sample_rate, prefix=""):
if not self.run:
return
prefix = "audios/" + prefix
for key, value in log_dict.items():
self.log_dict[prefix + key] = wandb.Audio(value, sample_rate=int(sample_rate))
def log_figures(self, log_dict, prefix=""):
if not self.run:
return
prefix = "figures/" + prefix
for key, value in log_dict.items():
self.log_dict[prefix + key] = wandb.Image(value)
def flush(self):
if self.run:
wandb.log(self.log_dict)
self.log_dict = {}
def finish(self):
"""
Finish this W&B run
"""
self.run.finish()
def log_artifact(self, file_or_dir, name, type, aliases=[]):
artifact = wandb.Artifact(name, type=type)
data_path = Path(file_or_dir)
if data_path.is_dir():
artifact.add_dir(str(data_path))
elif data_path.is_file():
artifact.add_file(str(data_path))
self.run.log_artifact(artifact, aliases=aliases)

View File

@ -25,6 +25,6 @@ config = AlignTTSConfig(
output_path=output_path, output_path=output_path,
datasets=[dataset_config], datasets=[dataset_config],
) )
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger) trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
trainer.fit() trainer.fit()

View File

@ -25,6 +25,6 @@ config = GlowTTSConfig(
output_path=output_path, output_path=output_path,
datasets=[dataset_config], datasets=[dataset_config],
) )
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger) trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
trainer.fit() trainer.fit()

View File

@ -24,6 +24,6 @@ config = HifiganConfig(
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"), data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
output_path=output_path, output_path=output_path,
) )
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger) trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
trainer.fit() trainer.fit()

View File

@ -24,6 +24,6 @@ config = MultibandMelganConfig(
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"), data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
output_path=output_path, output_path=output_path,
) )
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger) trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
trainer.fit() trainer.fit()

View File

@ -24,6 +24,6 @@ config = UnivnetConfig(
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"), data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
output_path=output_path, output_path=output_path,
) )
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger) trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
trainer.fit() trainer.fit()

View File

@ -22,6 +22,6 @@ config = WavegradConfig(
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"), data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
output_path=output_path, output_path=output_path,
) )
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger) trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
trainer.fit() trainer.fit()

View File

@ -24,6 +24,6 @@ config = WavernnConfig(
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"), data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
output_path=output_path, output_path=output_path,
) )
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=True) trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger, cudnn_benchmark=True)
trainer.fit() trainer.fit()