mirror of https://github.com/coqui-ai/TTS.git

WandbLogger

commit f5e50ad502
parent 06018251e6
@@ -38,7 +38,7 @@ from TTS.utils.generic_utils import (
     to_cuda,
 )
 from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
-from TTS.utils.logging import ConsoleLogger, TensorboardLogger
+from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
 from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
 from TTS.vocoder.models import setup_model as setup_vocoder_model
@@ -52,6 +52,7 @@ if platform.system() != "Windows":
     rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
     resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))

+
 if is_apex_available():
     from apex import amp

@@ -92,6 +93,7 @@ class Trainer:
         output_path: str,
         c_logger: ConsoleLogger = None,
         tb_logger: TensorboardLogger = None,
+        wandb_logger: WandbLogger = None,
         model: nn.Module = None,
         cudnn_benchmark: bool = False,
     ) -> None:
@@ -119,6 +121,9 @@ class Trainer:
             tb_logger (TensorboardLogger, optional): Tensorboard logger. If not provided, the default logger is used.
                 Defaults to None.

+            wandb_logger (WandbLogger, optional): W&B logger. If not provided, the default logger is used.
+                Defaults to None.
+
             model (nn.Module, optional): Initialized and ready-to-train model. If it is not defined, `Trainer`
                 initializes a model from the provided config. Defaults to None.

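A minimal sketch of what the new keyword argument enables: passing a custom WandbLogger instead of the default one the Trainer builds. The module path, the config class, and the project/run names are illustrative assumptions, not part of this commit.

from TTS.trainer import Trainer, TrainingArgs  # assumed module path
from TTS.tts.configs import GlowTTSConfig  # any model config; illustrative
from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger

config = GlowTTSConfig(output_path="output/")  # illustrative config
wandb_logger = WandbLogger(project="my-tts-project", name=config.model, config=config)
trainer = Trainer(
    TrainingArgs(),  # assumes the default TrainingArgs fields (rank, restore_path, ...)
    config,
    "output/",
    c_logger=ConsoleLogger(),
    tb_logger=TensorboardLogger("output/", model_name=config.model),
    wandb_logger=wandb_logger,
)
trainer.fit()
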
@@ -167,6 +172,16 @@ class Trainer:
             self.tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
         else:
             self.tb_logger = tb_logger

+        if wandb_logger is None:
+            self.wandb_logger = WandbLogger(
+                project=os.path.basename(output_path),
+                name=config.model,
+                config=config,
+            )
+        else:
+            self.wandb_logger = wandb_logger
+
         log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
         self._setup_logger_config(log_file)

@@ -627,6 +642,7 @@ class Trainer:
         # reduce TB load and don't log every step
         if self.total_steps_done % self.config.tb_plot_step == 0:
             self.tb_logger.tb_train_step_stats(self.total_steps_done, loss_dict)
+            self.wandb_logger.log(loss_dict, "train/", flush=True)
         if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0:
             if self.config.checkpoint:
                 # checkpoint the model
@@ -649,10 +665,14 @@ class Trainer:
             figures, audios = self.model.train_log(self.ap, batch, outputs)
             if figures is not None:
                 self.tb_logger.tb_train_figures(self.total_steps_done, figures)
+                self.wandb_logger.log_figures(figures, "train/")
             if audios is not None:
                 self.tb_logger.tb_train_audios(self.total_steps_done, audios, self.ap.sample_rate)
+                self.wandb_logger.log_audios(audios, self.ap.sample_rate, "train/")

         self.total_steps_done += 1
         self.callbacks.on_train_step_end()
+        self.wandb_logger.flush()
         return outputs, loss_dict

     def train_epoch(self) -> None:
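The step-level calls above rely on the accumulate-then-flush design of the new logger (defined in the new file further down): log(), log_figures(), and log_audios() only buffer entries in log_dict, and flush() sends the whole buffer in a single wandb.log() call, so all metrics for a step land together. A sketch of that flow, assuming wandb is installed and using a matplotlib figure as a stand-in for a training plot:

import matplotlib.pyplot as plt
from TTS.utils.logging import WandbLogger

logger = WandbLogger(project="demo")  # if wandb is not installed, flush() sends nothing

fig, ax = plt.subplots()
ax.plot([0, 1], [1, 0])  # stand-in for an alignment/spectrogram figure

logger.log({"loss": 0.42}, "train/")  # buffered as "train/loss"
logger.log_figures({"alignment": fig}, "train/")  # buffered as "figures/train/alignment"
logger.flush()  # one wandb.log() call; the buffer is then cleared
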
@@ -678,6 +698,7 @@ class Trainer:
         epoch_stats = {"epoch_time": epoch_time}
         epoch_stats.update(self.keep_avg_train.avg_values)
         self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
+        self.wandb_logger.log_scalars(epoch_stats, "train/")
         if self.config.tb_model_param_stats:
             self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
         # scheduler step after the epoch
@@ -781,9 +802,12 @@ class Trainer:
             figures, audios = self.model.eval_log(self.ap, batch, outputs)
             if figures is not None:
                 self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
+                self.wandb_logger.log_figures(figures, "eval/")
             if audios is not None:
                 self.tb_logger.tb_eval_audios(self.total_steps_done, audios, self.ap.sample_rate)
+                self.wandb_logger.log_audios(audios, self.ap.sample_rate, "eval/")
         self.tb_logger.tb_eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values)
+        self.wandb_logger.log_scalars(self.keep_avg_eval.avg_values, "eval/")

     def test_run(self) -> None:
         """Run test and log the results. Test run must be defined by the model.
@@ -803,6 +827,8 @@ class Trainer:
         figures, audios = self.model.test_run(self.ap)
         self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
         self.tb_logger.tb_test_figures(self.total_steps_done, figures)
+        self.wandb_logger.log_audios(audios, self.config.audio["sample_rate"], "test/")
+        self.wandb_logger.log_figures(figures, "test/")

     def _fit(self) -> None:
         """🏃 train -> evaluate -> test for the number of epochs."""
@@ -1066,6 +1092,7 @@ def process_args(args, config=None):
             logging to the console.
         tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does
             the TensorBoard logging.
+        wandb_logger (TTS.utils.tensorboard.WandbLogger): Class that does the W&B logging.

     TODO:
         - Interactive config definition.
@@ -1079,14 +1106,17 @@ def process_args(args, config=None):
         args.restore_path, best_model = get_last_checkpoint(args.continue_path)
         if not args.best_path:
             args.best_path = best_model
-    # init config if not already defined
+    # setup output paths and read configs
+    if config is None:
+        config = load_config(args.config_path)
+    # init config
     if config is None:
         if args.config_path:
             # init from a file
             config = load_config(args.config_path)
         else:
             # init from console args
-            from TTS.config.shared_configs import BaseTrainingConfig  # pylint: disable=import-outside-toplevel
+            from TTS.config.shared_configs import BaseTrainingConfig

             config_base = BaseTrainingConfig()
             config_base.parse_known_args(coqpit_overrides)
@@ -1101,6 +1131,7 @@ def process_args(args, config=None):
     audio_path = os.path.join(experiment_path, "test_audios")
     # setup rank 0 process in distributed training
     tb_logger = None
+    wandb_logger = None
     if args.rank == 0:
         new_fields = {}
         if args.restore_path:
@@ -1116,8 +1147,15 @@ def process_args(args, config=None):
         tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
         # write model desc to tensorboard
         tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
+
+        wandb_logger = WandbLogger(
+            project=config.model,
+            name=os.path.basename(experiment_path),
+            config=config,
+        )
+
     c_logger = ConsoleLogger()
-    return config, experiment_path, audio_path, c_logger, tb_logger
+    return config, experiment_path, audio_path, c_logger, tb_logger, wandb_logger


 def init_arguments():
@@ -1133,5 +1171,5 @@ def init_training(argv: Union[List, Coqpit], config: Coqpit = None):
     else:
         parser = init_arguments()
         args = parser.parse_known_args()
-    config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, config)
-    return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger
+    config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger, wandb_logger = process_args(args, config)
+    return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger, wandb_logger

@@ -1,2 +1,3 @@
 from TTS.utils.logging.console_logger import ConsoleLogger
 from TTS.utils.logging.tensorboard_logger import TensorboardLogger
+from TTS.utils.logging.wandb_logger import WandbLogger

@@ -0,0 +1,69 @@
+from pathlib import Path
+
+try:
+    import wandb
+    from wandb import init, finish
+except ImportError:
+    wandb = None
+
+
+class WandbLogger:
+    def __init__(self, **kwargs):
+        self.run = None
+        if wandb:
+            self.run = wandb.init(**kwargs) if not wandb.run else wandb.run
+        self.log_dict = {}
+
+    def log(self, log_dict, prefix="", flush=False):
+        """
+        This function accumulates data in self.log_dict. If flush is set,
+        the accumulated metrics are logged directly to the wandb dashboard.
+        """
+        for key, value in log_dict.items():
+            self.log_dict[prefix + key] = value
+        if flush:  # for cases where you don't want to accumulate data
+            self.flush()
+
+    def log_scalars(self, log_dict, prefix=""):
+        if not self.run:
+            return
+
+        for key, value in log_dict.items():
+            self.log_dict[prefix + key] = value
+
+    def log_audios(self, log_dict, sample_rate, prefix=""):
+        if not self.run:
+            return
+
+        prefix = "audios/" + prefix
+        for key, value in log_dict.items():
+            self.log_dict[prefix + key] = wandb.Audio(value, sample_rate=int(sample_rate))
+
+    def log_figures(self, log_dict, prefix=""):
+        if not self.run:
+            return
+
+        prefix = "figures/" + prefix
+        for key, value in log_dict.items():
+            self.log_dict[prefix + key] = wandb.Image(value)
+
+    def flush(self):
+        if self.run:
+            wandb.log(self.log_dict)
+        self.log_dict = {}
+
+    def finish(self):
+        """
+        Finish this W&B run.
+        """
+        self.run.finish()
+
+    def log_artifact(self, file_or_dir, name, type, aliases=[]):
+        artifact = wandb.Artifact(name, type=type)
+        data_path = Path(file_or_dir)
+        if data_path.is_dir():
+            artifact.add_dir(str(data_path))
+        elif data_path.is_file():
+            artifact.add_file(str(data_path))
+
+        self.run.log_artifact(artifact, aliases=aliases)
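Putting the new class together, a short end-to-end sketch covering the run lifecycle and the artifact helper. It assumes wandb is installed and logged in; the project name, audio array, and checkpoint path are illustrative.

import numpy as np
from TTS.utils.logging import WandbLogger

logger = WandbLogger(project="tts-demo", name="wandb-logger-example")

logger.log_scalars({"epoch_time": 12.3}, "train/")  # buffered until flush()
logger.log_audios({"sample": np.zeros(22050)}, 22050.0, "eval/")  # wrapped in wandb.Audio
logger.flush()  # single wandb.log() call

# Version a checkpoint directory as a W&B artifact, then close the run.
logger.log_artifact("output/checkpoints", name="checkpoints", type="model")
logger.finish()
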
@@ -25,6 +25,6 @@ config = AlignTTSConfig(
     output_path=output_path,
     datasets=[dataset_config],
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
 trainer.fit()

@@ -25,6 +25,6 @@ config = GlowTTSConfig(
     output_path=output_path,
     datasets=[dataset_config],
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
 trainer.fit()

@@ -24,6 +24,6 @@ config = HifiganConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
 trainer.fit()

@@ -24,6 +24,6 @@ config = MultibandMelganConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
 trainer.fit()

@@ -24,6 +24,6 @@ config = UnivnetConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
 trainer.fit()

@@ -22,6 +22,6 @@ config = WavegradConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
 trainer.fit()

@@ -24,6 +24,6 @@ config = WavernnConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=True)
+args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
+trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger, cudnn_benchmark=True)
 trainer.fit()