mirror of https://github.com/coqui-ai/TTS.git
Unified logger API
This commit is contained in:
parent f4434da5a3
commit f63cf46c55
@@ -208,7 +208,7 @@ def main(args):  # pylint: disable=redefined-outer-name
 
 
 if __name__ == "__main__":
-    args, c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger, wandb_logger = init_training(sys.argv)
+    args, c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training(sys.argv)
 
     try:
         main(args)
@@ -5,8 +5,8 @@ from TTS.trainer import Trainer, init_training
 
 
 def main():
     """Run 🐸TTS trainer from terminal. This is also necessary to run DDP training by ```distribute.py```"""
-    args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(sys.argv)
-    trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger, cudnn_benchmark=False)
+    args, config, output_path, _, c_logger, dashboard_logger = init_training(sys.argv)
+    trainer = Trainer(args, config, output_path, c_logger, dashboard_logger, cudnn_benchmark=False)
     trainer.fit()
@@ -8,8 +8,8 @@ from TTS.utils.generic_utils import remove_experiment_folder
 
 
 def main():
     try:
-        args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(sys.argv)
-        trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
+        args, config, output_path, _, c_logger, dashboard_logger = init_training(sys.argv)
+        trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
         trainer.fit()
     except KeyboardInterrupt:
         remove_experiment_folder(output_path)
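All three entry-point scripts above now unpack a single dashboard_logger from init_training instead of the former tb_logger/wandb_logger pair. A minimal sketch of the new calling convention, assuming this commit's TTS.trainer module is importable and a valid config is passed on the command line:

import sys

from TTS.trainer import Trainer, init_training


def main():
    # One logger handle comes back, regardless of the tracking backend.
    args, config, output_path, _, c_logger, dashboard_logger = init_training(sys.argv)
    trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
    trainer.fit()


if __name__ == "__main__":
    main()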
@@ -225,15 +225,17 @@ class BaseTrainingConfig(Coqpit):
         print_step (int):
             Number of steps required to print the next training log.
 
-        tb_plot_step (int):
+        log_dashboard (str): "tensorboard" or "wandb"
+            Set the experiment tracking tool
+
+        plot_step (int):
             Number of steps required to log training on Tensorboard.
 
-        tb_model_param_stats (bool):
+        model_param_stats (bool):
             Enable / Disable logging internal model stats for model diagnostic. It might be useful for model debugging.
             Defaults to ```False```.
-        wandb_disabled: bool = False
 
-        wandb_project_name (str):
+        project_name (str):
             Name of the W&B project. Defaults to config.model
 
         wandb_entity (str):
@@ -283,13 +285,13 @@ class BaseTrainingConfig(Coqpit):
     test_delay_epochs: int = 0
     print_eval: bool = False
     # logging
+    dashboard_logger: str = "tensorboard"
     print_step: int = 25
-    tb_plot_step: int = 100
-    tb_model_param_stats: bool = False
-    wandb_disabled: bool = False
-    wandb_project_name: str = None
+    plot_step: int = 100
+    model_param_stats: bool = False
+    project_name: str = None
+    log_model_step: int = None
     wandb_entity: str = None
-    wandb_log_model_step: int = None
    # checkpointing
     save_step: int = 10000
     checkpoint: bool = True
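Read together, the renamed fields form a backend-agnostic logging contract: one dashboard_logger switch plus neutral names (plot_step, model_param_stats, project_name, log_model_step). A self-contained sketch of that shape using stdlib dataclasses (an illustrative stand-in; the real class is a Coqpit):

from dataclasses import dataclass
from typing import Optional


@dataclass
class LoggingConfigSketch:
    dashboard_logger: str = "tensorboard"   # experiment tracker: "tensorboard" or "wandb"
    print_step: int = 25                    # console log frequency
    plot_step: int = 100                    # dashboard log frequency (was tb_plot_step)
    model_param_stats: bool = False         # weight-stats logging (was tb_model_param_stats)
    project_name: Optional[str] = None      # W&B project (was wandb_project_name)
    log_model_step: Optional[int] = None    # artifact upload frequency (was wandb_log_model_step)
    wandb_entity: Optional[str] = None


# selecting W&B as the tracking backend is now a one-field change
config = LoggingConfigSketch(dashboard_logger="wandb", project_name="my-tts-runs")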
TTS/trainer.py (145 changed lines)
@@ -92,8 +92,7 @@ class Trainer:
         config: Coqpit,
         output_path: str,
         c_logger: ConsoleLogger = None,
-        tb_logger: TensorboardLogger = None,
-        wandb_logger: WandbLogger = None,
+        dashboard_logger: Union[TensorboardLogger, WandbLogger] = None,
         model: nn.Module = None,
         cudnn_benchmark: bool = False,
     ) -> None:
@@ -118,10 +117,7 @@ class Trainer:
             c_logger (ConsoleLogger, optional): Console logger for printing training status. If not provided, the default
                 console logger is used. Defaults to None.
 
-            tb_logger (TensorboardLogger, optional): Tensorboard logger. If not provided, the default logger is used.
-                Defaults to None.
-
-            wandb_logger (WandbLogger, optional): W&B logger. If not provided, the default logger is used.
+            dashboard_logger Union[TensorboardLogger, WandbLogger]: Dashboard logger. If not provided, the tensorboard logger is used.
                 Defaults to None.
 
             model (nn.Module, optional): Initialized and ready-to-train model. If it is not defined, `Trainer`
@@ -143,8 +139,8 @@ class Trainer:
             Running trainer on a config.
 
             >>> config = WavegradConfig(data_path="/home/erogol/nvme/gdrive/Datasets/LJSpeech-1.1/wavs/", output_path=output_path,)
-            >>> args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-            >>> trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+            >>> args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
+            >>> trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
             >>> trainer.fit()
 
         TODO:
@@ -159,7 +155,7 @@ class Trainer:
         self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark)
         if config is None:
             # parse config from console arguments
-            config, output_path, _, c_logger, tb_logger, wandb_logger = process_args(args)
+            config, output_path, _, c_logger, dashboard_logger = process_args(args)
 
         self.output_path = output_path
         self.args = args
@@ -167,26 +163,27 @@ class Trainer:
 
         # init loggers
         self.c_logger = ConsoleLogger() if c_logger is None else c_logger
-        if tb_logger is None:
-            self.tb_logger = TensorboardLogger(output_path, model_name=config.model)
-            self.tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
-        else:
-            self.tb_logger = tb_logger
+        self.dashboard_logger = dashboard_logger
 
-        if wandb_logger is None:
-            wandb_project_name = config.model
-            if config.wandb_project_name:
-                wandb_project_name = config.wandb_project_name
+        if self.dashboard_logger is None:
+            if config.dashboard_logger == "tensorboard":
+                self.dashboard_logger = TensorboardLogger(output_path, model_name=config.model)
+                self.dashboard_logger.add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
 
-            wandb_logger = WandbLogger(
-                disabled=config.wandb_disabled,
-                project=wandb_project_name,
-                name=config.run_name,
-                config=config,
-                entity=config.wandb_entity,
-            )
-        else:
-            self.wandb_logger = wandb_logger
+            elif config.dashboard_logger == "wandb":
+                project_name = config.model
+                if config.project_name:
+                    project_name = config.project_name
+
+                self.dashboard_logger = WandbLogger(
+                    project=project_name,
+                    name=config.run_name,
+                    config=config,
+                    entity=config.wandb_entity,
+                )
+                if not self.config.log_model_step:
+                    self.config.log_model_step = self.config.save_step
 
         log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
         self._setup_logger_config(log_file)
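The branching in __init__ is effectively a factory: when no logger was passed in, instantiate the backend named in the config and fall through to it everywhere else. A hedged standalone sketch of the same dispatch (the helper name is mine; TensorboardLogger and WandbLogger stand in for this commit's classes and are assumed to be imported):

def setup_dashboard_logger(config, output_path):
    # Mirrors the Trainer.__init__ branching above: pick the backend from config.
    if config.dashboard_logger == "tensorboard":
        logger = TensorboardLogger(output_path, model_name=config.model)
        logger.add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
        return logger
    if config.dashboard_logger == "wandb":
        return WandbLogger(
            project=config.project_name or config.model,
            name=config.run_name,
            config=config,
            entity=config.wandb_entity,
        )
    raise ValueError(f"unknown dashboard_logger: {config.dashboard_logger!r}")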
@@ -646,9 +643,8 @@ class Trainer:
         if self.args.rank == 0:
             # Plot Training Iter Stats
             # reduce TB load and don't log every step
-            if self.total_steps_done % self.config.tb_plot_step == 0:
-                self.tb_logger.tb_train_step_stats(self.total_steps_done, loss_dict)
-                self.wandb_logger.log(loss_dict, "train/", flush=True)
+            if self.total_steps_done % self.config.plot_step == 0:
+                self.dashboard_logger.train_step_stats(self.total_steps_done, loss_dict)
             if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0:
                 if self.config.checkpoint:
                     # checkpoint the model
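Logging every iteration would flood the dashboard, so the trainer gates stats behind total_steps_done % plot_step == 0 (and, below, artifact uploads behind log_model_step). The gating pattern in isolation, as a runnable toy loop with placeholder values:

plot_step = 100   # mirrors config.plot_step
save_step = 500   # mirrors config.save_step

for step in range(1, 1001):
    loss = 1.0 / step  # placeholder metric
    if step % plot_step == 0:
        # stand-in for dashboard_logger.train_step_stats(step, {"loss": loss})
        print(f"step {step}: log loss={loss:.4f}")
    if step % save_step == 0:
        # stand-in for checkpointing plus dashboard_logger.log_artifact(...)
        print(f"step {step}: save checkpoint")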
@@ -664,13 +660,10 @@ class Trainer:
                         model_loss=target_avg_loss,
                     )
 
-                    if (
-                        self.config.wandb_log_model_step
-                        and self.total_steps_done % self.config.wandb_log_model_step == 0
-                    ):
-                        # log checkpoint as W&B artifact
+                    if self.total_steps_done % self.config.log_model_step == 0:
+                        # log checkpoint as artifact
                         aliases = [f"epoch-{self.epochs_done}", f"step-{self.total_steps_done}"]
-                        self.wandb_logger.log_artifact(self.output_path, "checkpoint", "model", aliases)
+                        self.dashboard_logger.log_artifact(self.output_path, "checkpoint", "model", aliases)
 
             # training visualizations
             figures, audios = None, None
@@ -679,15 +672,13 @@ class Trainer:
             elif hasattr(self.model, "train_log"):
                 figures, audios = self.model.train_log(self.ap, batch, outputs)
             if figures is not None:
-                self.tb_logger.tb_train_figures(self.total_steps_done, figures)
-                self.wandb_logger.log_figures(figures, "train/")
+                self.dashboard_logger.train_figures(self.total_steps_done, figures)
             if audios is not None:
-                self.tb_logger.tb_train_audios(self.total_steps_done, audios, self.ap.sample_rate)
-                self.wandb_logger.log_audios(audios, self.ap.sample_rate, "train/")
+                self.dashboard_logger.train_audios(self.total_steps_done, audios, self.ap.sample_rate)
 
         self.total_steps_done += 1
         self.callbacks.on_train_step_end()
-        self.wandb_logger.flush()
+        self.dashboard_logger.flush()
         return outputs, loss_dict
 
     def train_epoch(self) -> None:
@@ -712,10 +703,9 @@ class Trainer:
         if self.args.rank == 0:
             epoch_stats = {"epoch_time": epoch_time}
             epoch_stats.update(self.keep_avg_train.avg_values)
-            self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
-            self.wandb_logger.log_scalars(epoch_stats, "train/")
-            if self.config.tb_model_param_stats:
-                self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
+            self.dashboard_logger.train_epoch_stats(self.total_steps_done, epoch_stats)
+            if self.config.model_param_stats:
+                self.logger.model_weights(self.model, self.total_steps_done)
         # scheduler step after the epoch
         if self.scheduler is not None and self.config.scheduler_after_epoch:
             if isinstance(self.scheduler, list):
@@ -816,13 +806,10 @@ class Trainer:
             elif hasattr(self.model, "eval_log"):
                 figures, audios = self.model.eval_log(self.ap, batch, outputs)
             if figures is not None:
-                self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
-                self.wandb_logger.log_figures(figures, "eval/")
+                self.dashboard_logger.eval_figures(self.total_steps_done, figures)
             if audios is not None:
-                self.tb_logger.tb_eval_audios(self.total_steps_done, audios, self.ap.sample_rate)
-                self.wandb_logger.log_audios(audios, self.ap.sample_rate, "eval/")
-            self.tb_logger.tb_eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values)
-            self.wandb_logger.log_scalars(self.keep_avg_eval.avg_values, "eval/")
+                self.dashboard_logger.eval_audios(self.total_steps_done, audios, self.ap.sample_rate)
+            self.dashboard_logger.eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values)
 
     def test_run(self) -> None:
         """Run test and log the results. Test run must be defined by the model.
@@ -839,11 +826,9 @@ class Trainer:
             samples = self.eval_loader.dataset.load_test_samples(1)
             figures, audios = self.model.test_run(self.ap, samples, None)
         else:
-            figures, audios = self.model.test_run(self.ap)
-        self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
-        self.tb_logger.tb_test_figures(self.total_steps_done, figures)
-        self.wandb_logger.log_audios(audios, self.config.audio["sample_rate"], "test/")
-        self.wandb_logger.log_figures(figures, "test/")
+            figures, audios = self.model.test_run()
+        self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
+        self.dashboard_logger.test_figures(self.total_steps_done, figures)
 
     def _fit(self) -> None:
         """🏃 train -> evaluate -> test for the number of epochs."""
@@ -875,13 +860,13 @@ class Trainer:
         """Where the ✨️magic✨️ happens..."""
         try:
             self._fit()
-            self.wandb_logger.finish()
+            self.dashboard_logger.finish()
         except KeyboardInterrupt:
             self.callbacks.on_keyboard_interrupt()
             # if the output folder is empty remove the run.
             remove_experiment_folder(self.output_path)
             # finish the wandb run and sync data
-            self.wandb_logger.finish()
+            self.dashboard_logger.finish()
             # stop without error signal
             try:
                 sys.exit(0)
@@ -1108,9 +1093,8 @@ def process_args(args, config=None):
         audio_path (str): Path to save generated test audios.
         c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
             logging to the console.
-        tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does
-            the TensorBoard logging.
-        wandb_logger (TTS.utils.tensorboard.WandbLogger): Class that does the W&B Logging
+
+        dashboard_logger (WandbLogger or TensorboardLogger): Class that does the dashboard Logging
 
     TODO:
         - Interactive config definition.
@@ -1146,8 +1130,7 @@ def process_args(args, config=None):
     experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
     audio_path = os.path.join(experiment_path, "test_audios")
     # setup rank 0 process in distributed training
-    tb_logger = None
-    wandb_logger = None
+    dashboard_logger = None
     if args.rank == 0:
         new_fields = {}
         if args.restore_path:
@@ -1160,24 +1143,28 @@ def process_args(args, config=None):
             used_characters = parse_symbols()
             new_fields["characters"] = used_characters
         copy_model_files(config, experiment_path, new_fields)
-        tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
-        # write model desc to tensorboard
-        tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
         os.chmod(audio_path, 0o775)
         os.chmod(experiment_path, 0o775)
 
-        wandb_project_name = config.model
-        if config.wandb_project_name:
-            wandb_project_name = config.wandb_project_name
+        if config.dashboard_logger == "tensorboard":
+            dashboard_logger = TensorboardLogger(output_path, model_name=config.model)
+            dashboard_logger.add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
 
-        wandb_logger = WandbLogger(
-            disabled=config.wandb_disabled,
-            project=wandb_project_name,
-            name=config.run_name,
-            config=config,
-            entity=config.wandb_entity,
-        )
+        elif config.dashboard_logger == "wandb":
+            project_name = config.model
+            if config.project_name:
+                project_name = config.project_name
+
+            dashboard_logger = WandbLogger(
+                project=project_name,
+                name=config.run_name,
+                config=config,
+                entity=config.wandb_entity,
+            )
 
     c_logger = ConsoleLogger()
-    return config, experiment_path, audio_path, c_logger, tb_logger, wandb_logger
+    return config, experiment_path, audio_path, c_logger, dashboard_logger
 
 
 def init_arguments():
@@ -1193,5 +1180,5 @@ def init_training(argv: Union[List, Coqpit], config: Coqpit = None):
     else:
         parser = init_arguments()
         args = parser.parse_known_args()
-    config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger, wandb_logger = process_args(args, config)
-    return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger, wandb_logger
+    config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = process_args(args, config)
+    return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger
@@ -10,7 +10,7 @@ class TensorboardLogger(object):
         self.train_stats = {}
         self.eval_stats = {}
 
-    def tb_model_weights(self, model, step):
+    def model_weights(self, model, step):
         layer_num = 1
         for name, param in model.named_parameters():
             if param.numel() == 1:
@@ -41,32 +41,41 @@ class TensorboardLogger(object):
         except RuntimeError:
             traceback.print_exc()
 
-    def tb_train_step_stats(self, step, stats):
+    def train_step_stats(self, step, stats):
         self.dict_to_tb_scalar(f"{self.model_name}_TrainIterStats", stats, step)
 
-    def tb_train_epoch_stats(self, step, stats):
+    def train_epoch_stats(self, step, stats):
         self.dict_to_tb_scalar(f"{self.model_name}_TrainEpochStats", stats, step)
 
-    def tb_train_figures(self, step, figures):
+    def train_figures(self, step, figures):
         self.dict_to_tb_figure(f"{self.model_name}_TrainFigures", figures, step)
 
-    def tb_train_audios(self, step, audios, sample_rate):
+    def train_audios(self, step, audios, sample_rate):
         self.dict_to_tb_audios(f"{self.model_name}_TrainAudios", audios, step, sample_rate)
 
-    def tb_eval_stats(self, step, stats):
+    def eval_stats(self, step, stats):
         self.dict_to_tb_scalar(f"{self.model_name}_EvalStats", stats, step)
 
-    def tb_eval_figures(self, step, figures):
+    def eval_figures(self, step, figures):
         self.dict_to_tb_figure(f"{self.model_name}_EvalFigures", figures, step)
 
-    def tb_eval_audios(self, step, audios, sample_rate):
+    def eval_audios(self, step, audios, sample_rate):
         self.dict_to_tb_audios(f"{self.model_name}_EvalAudios", audios, step, sample_rate)
 
-    def tb_test_audios(self, step, audios, sample_rate):
+    def test_audios(self, step, audios, sample_rate):
         self.dict_to_tb_audios(f"{self.model_name}_TestAudios", audios, step, sample_rate)
 
-    def tb_test_figures(self, step, figures):
+    def test_figures(self, step, figures):
         self.dict_to_tb_figure(f"{self.model_name}_TestFigures", figures, step)
 
-    def tb_add_text(self, title, text, step):
+    def add_text(self, title, text, step):
         self.writer.add_text(title, text, step)
 
+    def log_artifact(self, file_or_dir, name, artifact_type, aliases=None):
+        return
+
+    def flush(self):
+        return
+
+    def finish(self):
+        return
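After the rename, both loggers expose the same method surface (train_step_stats, eval_figures, test_audios, log_artifact, flush, finish, ...), which is why the no-op log_artifact/flush/finish stubs are added to TensorboardLogger above. A sketch of that implicit contract as a typing.Protocol (the Protocol itself is my addition for illustration; the repo relies on duck typing):

from typing import Any, Dict, Protocol


class DashboardLogger(Protocol):
    """Implicit interface shared by TensorboardLogger and WandbLogger in this commit."""

    def train_step_stats(self, step: int, stats: Dict[str, Any]) -> None: ...
    def train_epoch_stats(self, step: int, stats: Dict[str, Any]) -> None: ...
    def train_figures(self, step: int, figures: Dict[str, Any]) -> None: ...
    def train_audios(self, step: int, audios: Dict[str, Any], sample_rate: int) -> None: ...
    def eval_stats(self, step: int, stats: Dict[str, Any]) -> None: ...
    def eval_figures(self, step: int, figures: Dict[str, Any]) -> None: ...
    def eval_audios(self, step: int, audios: Dict[str, Any], sample_rate: int) -> None: ...
    def test_audios(self, step: int, audios: Dict[str, Any], sample_rate: int) -> None: ...
    def test_figures(self, step: int, figures: Dict[str, Any]) -> None: ...
    def add_text(self, title: str, text: str, step: int) -> None: ...
    def log_artifact(self, file_or_dir: str, name: str, artifact_type: str, aliases=None) -> None: ...
    def flush(self) -> None: ...
    def finish(self) -> None: ...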
@@ -1,4 +1,5 @@
 from pathlib import Path
+import traceback
 
 try:
     import wandb
@@ -8,40 +9,85 @@ except ImportError:
 
 
 class WandbLogger:
-    def __init__(self, disabled=False, **kwargs):
+    def __init__(self, **kwargs):
+
+        if not wandb:
+            raise Exception("install wandb using `pip install wandb` to use WandbLogger")
+
         self.run = None
-        if wandb and not disabled:
-            self.run = wandb.init(**kwargs) if not wandb.run else wandb.run
+        self.run = wandb.init(**kwargs) if not wandb.run else wandb.run
+        self.model_name = self.run.config.model
         self.log_dict = {}
 
+    def model_weights(self, model):
+        layer_num = 1
+        for name, param in model.named_parameters():
+            if param.numel() == 1:
+                self.dict_to_scalar("weights", {"layer{}-{}/value".format(layer_num, name): param.max()})
+            else:
+                self.dict_to_scalar("weights", {"layer{}-{}/max".format(layer_num, name): param.max()})
+                self.dict_to_scalar("weights", {"layer{}-{}/min".format(layer_num, name): param.min()})
+                self.dict_to_scalar("weights", {"layer{}-{}/mean".format(layer_num, name): param.mean()})
+                self.dict_to_scalar("weights", {"layer{}-{}/std".format(layer_num, name): param.std()})
+                '''
+                self.writer.add_histogram("layer{}-{}/param".format(layer_num, name), param, step)
+                self.writer.add_histogram("layer{}-{}/grad".format(layer_num, name), param.grad, step)
+                '''
+            layer_num += 1
+
+    def dict_to_scalar(self, scope_name, stats):
+        for key, value in stats.items():
+            self.log_dict["{}/{}".format(scope_name, key)] = value
+
+    def dict_to_figure(self, scope_name, figures):
+        for key, value in figures.items():
+            self.log_dict["{}/{}".format(scope_name, key)] = wandb.Image(value)
+
+    def dict_to_audios(self, scope_name, audios, sample_rate):
+        for key, value in audios.items():
+            if value.dtype == "float16":
+                value = value.astype("float32")
+            try:
+                self.log_dict["{}/{}".format(scope_name, key)] = wandb.Audio(value, sample_rate=sample_rate)
+            except RuntimeError:
+                traceback.print_exc()
 
     def log(self, log_dict, prefix="", flush=False):
         for key, value in log_dict.items():
             self.log_dict[prefix + key] = value
         if flush:  # for cases where you don't want to accumulate data
             self.flush()
 
-    def log_scalars(self, log_dict, prefix=""):
-        if not self.run:
-            return
-        for key, value in log_dict.items():
-            self.log_dict[prefix + key] = value
+    def train_step_stats(self, step, stats):
+        self.dict_to_scalar(f"{self.model_name}_TrainIterStats", stats)
 
-    def log_audios(self, log_dict, sample_rate, prefix=""):
-        if not self.run:
-            return
-        prefix = "audios/" + prefix
-        for key, value in log_dict.items():
-            self.log_dict[prefix + key] = wandb.Audio(value, sample_rate=int(sample_rate))
+    def train_epoch_stats(self, step, stats):
+        self.dict_to_scalar(f"{self.model_name}_TrainEpochStats", stats)
+
+    def train_figures(self, step, figures):
+        self.dict_to_figure(f"{self.model_name}_TrainFigures", figures)
+
+    def train_audios(self, step, audios, sample_rate):
+        self.dict_to_audios(f"{self.model_name}_TrainAudios", audios, sample_rate)
 
-    def log_figures(self, log_dict, prefix=""):
-        if not self.run:
-            return
-        prefix = "figures/" + prefix
-        for key, value in log_dict.items():
-            self.log_dict[prefix + key] = wandb.Image(value)
+    def eval_stats(self, step, stats):
+        self.dict_to_scalar(f"{self.model_name}_EvalStats", stats)
+
+    def eval_figures(self, step, figures):
+        self.dict_to_figure(f"{self.model_name}_EvalFigures", figures)
+
+    def eval_audios(self, step, audios, sample_rate):
+        self.dict_to_audios(f"{self.model_name}_EvalAudios", audios, sample_rate)
+
+    def test_audios(self, step, audios, sample_rate):
+        self.dict_to_audios(f"{self.model_name}_TestAudios", audios, sample_rate)
+
+    def test_figures(self, step, figures):
+        self.dict_to_figure(f"{self.model_name}_TestFigures", figures)
+
+    def add_text(self, title, text, step):
+        self.log_dict[title] = wandb.HTML(f"<p> {text} </p>")
 
     def flush(self):
         if self.run:
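The new WandbLogger never writes to the backend directly: dict_to_scalar/dict_to_figure/dict_to_audios only stage values in self.log_dict, and flush() (cut off above) presumably commits them in a single wandb.log call per step. The accumulate-then-flush idea in isolation, as a minimal sketch with the import guarded so the snippet runs even without wandb installed:

try:
    import wandb
except ImportError:
    wandb = None


class BufferedLoggerSketch:
    def __init__(self):
        self.log_dict = {}

    def log(self, values, prefix=""):
        # Stage values under prefixed keys; nothing is sent yet.
        for key, value in values.items():
            self.log_dict[prefix + key] = value

    def flush(self):
        # Commit one consolidated record per step, then clear the buffer.
        if wandb is not None and wandb.run:
            wandb.log(self.log_dict)
        self.log_dict = {}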