Mirror of https://github.com/coqui-ai/TTS.git

Commit f5e50ad502 (parent 06018251e6): WandbLogger
@@ -38,7 +38,7 @@ from TTS.utils.generic_utils import (
     to_cuda,
 )
 from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
-from TTS.utils.logging import ConsoleLogger, TensorboardLogger
+from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
 from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
 from TTS.vocoder.models import setup_model as setup_vocoder_model

@@ -52,6 +52,7 @@ if platform.system() != "Windows":
     rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
     resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))


 if is_apex_available():
     from apex import amp

@@ -92,6 +93,7 @@ class Trainer:
         output_path: str,
         c_logger: ConsoleLogger = None,
         tb_logger: TensorboardLogger = None,
+        wandb_logger: WandbLogger = None,
         model: nn.Module = None,
         cudnn_benchmark: bool = False,
     ) -> None:

@@ -119,6 +121,9 @@ class Trainer:
             tb_logger (TensorboardLogger, optional): Tensorboard logger. If not provided, the default logger is used.
                 Defaults to None.

+            wandb_logger (WandbLogger, optional): W&B logger. If not provided, the default logger is used.
+                Defaults to None.
+
             model (nn.Module, optional): Initialized and ready-to-train model. If it is not defined, `Trainer`
                 initializes a model from the provided config. Defaults to None.

@@ -167,6 +172,16 @@ class Trainer:
             self.tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
         else:
             self.tb_logger = tb_logger

+        if wandb_logger is None:
+            self.wandb_logger = WandbLogger(
+                project=os.path.basename(output_path),
+                name=config.model,
+                config=config,
+            )
+        else:
+            self.wandb_logger = wandb_logger
+
         log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
         self._setup_logger_config(log_file)
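
Because `WandbLogger.__init__` forwards its keyword arguments to `wandb.init` (see the new `wandb_logger.py` below), a caller can configure the run fully and pass the logger in, skipping the default constructed above. A minimal sketch under that assumption; `entity` and `tags` are illustrative values, not part of this commit:

    from TTS.utils.logging import WandbLogger

    wandb_logger = WandbLogger(
        project="tts-experiments",  # any wandb.init() keyword passes through
        name=config.model,
        config=config,
        entity="my-team",           # illustrative
        tags=["ljspeech"],          # illustrative
    )
    trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)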

@@ -627,6 +642,7 @@ class Trainer:
         # reduce TB load and don't log every step
         if self.total_steps_done % self.config.tb_plot_step == 0:
             self.tb_logger.tb_train_step_stats(self.total_steps_done, loss_dict)
+            self.wandb_logger.log(loss_dict, "train/", flush=True)
         if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0:
             if self.config.checkpoint:
                 # checkpoint the model
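
Note the two logging modes: `log(..., flush=True)` sends the scalars to W&B immediately, while `log_figures`/`log_audios` in the next hunk only stage entries in the logger's internal dict until `flush()` runs at the end of the train step. A minimal sketch of that lifecycle, using only methods added by this commit (`fig` stands in for any matplotlib figure):

    logger = WandbLogger(project="demo")
    logger.log({"loss": 0.42}, prefix="train/", flush=True)  # pushed to W&B now
    logger.log_figures({"alignment": fig}, "train/")         # staged as "figures/train/alignment"
    logger.flush()                                           # one wandb.log() call with all staged entries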

@@ -649,10 +665,14 @@ class Trainer:
                 figures, audios = self.model.train_log(self.ap, batch, outputs)
                 if figures is not None:
                     self.tb_logger.tb_train_figures(self.total_steps_done, figures)
+                    self.wandb_logger.log_figures(figures, "train/")
                 if audios is not None:
                     self.tb_logger.tb_train_audios(self.total_steps_done, audios, self.ap.sample_rate)
+                    self.wandb_logger.log_audios(audios, self.ap.sample_rate, "train/")

         self.total_steps_done += 1
         self.callbacks.on_train_step_end()
+        self.wandb_logger.flush()
         return outputs, loss_dict

     def train_epoch(self) -> None:

@@ -678,6 +698,7 @@ class Trainer:
         epoch_stats = {"epoch_time": epoch_time}
         epoch_stats.update(self.keep_avg_train.avg_values)
         self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
+        self.wandb_logger.log_scalars(epoch_stats, "train/")
         if self.config.tb_model_param_stats:
             self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
         # scheduler step after the epoch
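
Since `log_scalars` only stages values, these epoch stats reach the dashboard on the next `flush()`, i.e. at the end of the following train step.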

@@ -781,9 +802,12 @@ class Trainer:
                 figures, audios = self.model.eval_log(self.ap, batch, outputs)
                 if figures is not None:
                     self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
+                    self.wandb_logger.log_figures(figures, "eval/")
                 if audios is not None:
                     self.tb_logger.tb_eval_audios(self.total_steps_done, audios, self.ap.sample_rate)
+                    self.wandb_logger.log_audios(audios, self.ap.sample_rate, "eval/")
             self.tb_logger.tb_eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values)
+            self.wandb_logger.log_scalars(self.keep_avg_eval.avg_values, "eval/")

     def test_run(self) -> None:
         """Run test and log the results. Test run must be defined by the model.

@@ -803,6 +827,8 @@ class Trainer:
             figures, audios = self.model.test_run(self.ap)
             self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
             self.tb_logger.tb_test_figures(self.total_steps_done, figures)
+            self.wandb_logger.log_audios(audios, self.config.audio["sample_rate"], "test/")
+            self.wandb_logger.log_figures(figures, "test/")

     def _fit(self) -> None:
         """🏃 train -> evaluate -> test for the number of epochs."""

@@ -1066,6 +1092,7 @@ def process_args(args, config=None):
             logging to the console.
         tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does
             the TensorBoard logging.
+        wandb_logger (TTS.utils.logging.WandbLogger): Class that does the W&B logging.

     TODO:
         - Interactive config definition.

@@ -1079,14 +1106,17 @@ def process_args(args, config=None):
         args.restore_path, best_model = get_last_checkpoint(args.continue_path)
         if not args.best_path:
             args.best_path = best_model
-    # init config if not already defined
     # setup output paths and read configs
-    if config is None:
-        config = load_config(args.config_path)
+    # init config
+    if config is None:
+        if args.config_path:
+            # init from a file
+            config = load_config(args.config_path)
+        else:
+            # init from console args
-            from TTS.config.shared_configs import BaseTrainingConfig  # pylint: disable=import-outside-toplevel
+            from TTS.config.shared_configs import BaseTrainingConfig

             config_base = BaseTrainingConfig()
             config_base.parse_known_args(coqpit_overrides)

@@ -1101,6 +1131,7 @@ def process_args(args, config=None):
     audio_path = os.path.join(experiment_path, "test_audios")
     # setup rank 0 process in distributed training
     tb_logger = None
+    wandb_logger = None
     if args.rank == 0:
         new_fields = {}
         if args.restore_path:

@@ -1116,8 +1147,15 @@ def process_args(args, config=None):
         tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
         # write model desc to tensorboard
         tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
+
+        wandb_logger = WandbLogger(
+            project=config.model,
+            name=os.path.basename(experiment_path),
+            config=config,
+        )
+
     c_logger = ConsoleLogger()
-    return config, experiment_path, audio_path, c_logger, tb_logger
+    return config, experiment_path, audio_path, c_logger, tb_logger, wandb_logger


 def init_arguments():

@@ -1133,5 +1171,5 @@ def init_training(argv: Union[List, Coqpit], config: Coqpit = None):
     else:
         parser = init_arguments()
     args = parser.parse_known_args()
-    config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, config)
-    return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger
+    config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger, wandb_logger = process_args(args, config)
+    return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger, wandb_logger
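
Callers of `init_training` now unpack seven values instead of six, as each recipe below shows, e.g.:

    args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger, wandb_logger = init_training(sys.argv)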

@@ -1,2 +1,3 @@
 from TTS.utils.logging.console_logger import ConsoleLogger
 from TTS.utils.logging.tensorboard_logger import TensorboardLogger
+from TTS.utils.logging.wandb_logger import WandbLogger

@@ -0,0 +1,69 @@
+from pathlib import Path
+
+try:
+    import wandb
+    from wandb import init, finish
+except ImportError:
+    wandb = None
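
With this guard, wandb stays an optional dependency: if the import fails, `wandb` is None, no run is created, and the methods below either return early or skip the `wandb.log` call, so training proceeds with W&B logging disabled.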
+
+
+class WandbLogger:
+    def __init__(self, **kwargs):
+        self.run = None
+        if wandb:
+            # reuse the active run if one exists, otherwise start a new one
+            self.run = wandb.init(**kwargs) if not wandb.run else wandb.run
+        self.log_dict = {}
+
+    def log(self, log_dict, prefix="", flush=False):
+        """Accumulate data in self.log_dict. If `flush` is set,
+        the accumulated metrics are logged directly to the wandb dashboard.
+        """
+        for key, value in log_dict.items():
+            self.log_dict[prefix + key] = value
+        if flush:  # for cases where you don't want to accumulate data
+            self.flush()
+
+    def log_scalars(self, log_dict, prefix=""):
+        if not self.run:
+            return
+
+        for key, value in log_dict.items():
+            self.log_dict[prefix + key] = value
+
+    def log_audios(self, log_dict, sample_rate, prefix=""):
+        if not self.run:
+            return
+
+        prefix = "audios/" + prefix
+        for key, value in log_dict.items():
+            self.log_dict[prefix + key] = wandb.Audio(value, sample_rate=int(sample_rate))
+
+    def log_figures(self, log_dict, prefix=""):
+        if not self.run:
+            return
+
+        prefix = "figures/" + prefix
+        for key, value in log_dict.items():
+            self.log_dict[prefix + key] = wandb.Image(value)
+
+    def flush(self):
+        if self.run:
+            wandb.log(self.log_dict)
+        self.log_dict = {}
+
+    def finish(self):
+        """Finish this W&B run."""
+        if self.run:  # guard: self.run is None when wandb is not installed
+            self.run.finish()
+
+    def log_artifact(self, file_or_dir, name, type, aliases=None):
+        if not self.run:  # no-op when wandb is unavailable
+            return
+        aliases = aliases or []  # avoid a mutable default argument
+        artifact = wandb.Artifact(name, type=type)
+        data_path = Path(file_or_dir)
+        if data_path.is_dir():
+            artifact.add_dir(str(data_path))
+        elif data_path.is_file():
+            artifact.add_file(str(data_path))
+
+        self.run.log_artifact(artifact, aliases=aliases)
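
Taken together, the class is a thin buffering wrapper around the wandb client. A minimal standalone sketch of the call pattern, assuming wandb is installed; the project name and checkpoint path are illustrative:

    logger = WandbLogger(project="coqui-tts", name="demo-run")
    logger.log({"loss": 0.31, "lr": 1e-4}, prefix="train/", flush=True)
    logger.log_artifact("output/checkpoint_10000.pth", name="checkpoint", type="model", aliases=["latest"])
    logger.finish()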

@@ -25,6 +25,6 @@ config = AlignTTSConfig(
     output_path=output_path,
     datasets=[dataset_config],
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
 trainer.fit()

@@ -25,6 +25,6 @@ config = GlowTTSConfig(
     output_path=output_path,
     datasets=[dataset_config],
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
 trainer.fit()

@@ -24,6 +24,6 @@ config = HifiganConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
 trainer.fit()

@@ -24,6 +24,6 @@ config = MultibandMelganConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
 trainer.fit()

@@ -24,6 +24,6 @@ config = UnivnetConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
 trainer.fit()

@@ -22,6 +22,6 @@ config = WavegradConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
 trainer.fit()

@@ -24,6 +24,6 @@ config = WavernnConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=True)
+args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger, cudnn_benchmark=True)
 trainer.fit()