# mirror of https://github.com/coqui-ai/TTS.git
import traceback
from pathlib import Path

try:
    import wandb
    from wandb import finish, init  # pylint: disable=W0611
except ImportError:
    wandb = None


class WandbLogger:
    """Weights & Biases logger.

    Accumulates scalars, figures, and audio samples in ``self.log_dict``
    and pushes them to the active W&B run when ``flush()`` is called.
    """

    def __init__(self, **kwargs):
        if not wandb:
            raise Exception("install wandb using `pip install wandb` to use WandbLogger")

        # Reuse an already active run if one exists, otherwise start a new one.
        self.run = wandb.init(**kwargs) if not wandb.run else wandb.run
        self.model_name = self.run.config.model
        self.log_dict = {}

    def model_weights(self, model):
        """Log per-layer weight statistics of ``model``."""
        layer_num = 1
        for name, param in model.named_parameters():
            if param.numel() == 1:
                self.dict_to_scalar("weights", {"layer{}-{}/value".format(layer_num, name): param.max()})
            else:
                self.dict_to_scalar("weights", {"layer{}-{}/max".format(layer_num, name): param.max()})
                self.dict_to_scalar("weights", {"layer{}-{}/min".format(layer_num, name): param.min()})
                self.dict_to_scalar("weights", {"layer{}-{}/mean".format(layer_num, name): param.mean()})
                self.dict_to_scalar("weights", {"layer{}-{}/std".format(layer_num, name): param.std()})
                # Disabled histogram logging (references ``self.writer``, which this logger does not define):
                # self.writer.add_histogram("layer{}-{}/param".format(layer_num, name), param, step)
                # self.writer.add_histogram("layer{}-{}/grad".format(layer_num, name), param.grad, step)
            layer_num += 1

    def dict_to_scalar(self, scope_name, stats):
        """Add scalar values to the pending log dict under ``scope_name``."""
        for key, value in stats.items():
            self.log_dict["{}/{}".format(scope_name, key)] = value

    def dict_to_figure(self, scope_name, figures):
        """Add figures to the pending log dict as ``wandb.Image`` entries."""
        for key, value in figures.items():
            self.log_dict["{}/{}".format(scope_name, key)] = wandb.Image(value)

    def dict_to_audios(self, scope_name, audios, sample_rate):
        """Add audio arrays to the pending log dict as ``wandb.Audio`` entries."""
        for key, value in audios.items():
            if value.dtype == "float16":
                # Convert half-precision audio to float32 before handing it to wandb.Audio.
                value = value.astype("float32")
            try:
                self.log_dict["{}/{}".format(scope_name, key)] = wandb.Audio(value, sample_rate=sample_rate)
            except RuntimeError:
                traceback.print_exc()

    def log(self, log_dict, prefix="", flush=False):
        for key, value in log_dict.items():
            self.log_dict[prefix + key] = value
        if flush:  # for cases where you don't want to accumulate data
            self.flush()

    # ``step`` is accepted to match the shared logger interface but is not used here;
    # entries are pushed with the run's own step counter on ``flush()``.
    def train_step_stats(self, step, stats):
        self.dict_to_scalar(f"{self.model_name}_TrainIterStats", stats)

    def train_epoch_stats(self, step, stats):
        self.dict_to_scalar(f"{self.model_name}_TrainEpochStats", stats)

    def train_figures(self, step, figures):
        self.dict_to_figure(f"{self.model_name}_TrainFigures", figures)

    def train_audios(self, step, audios, sample_rate):
        self.dict_to_audios(f"{self.model_name}_TrainAudios", audios, sample_rate)

    def eval_stats(self, step, stats):
        self.dict_to_scalar(f"{self.model_name}_EvalStats", stats)

    def eval_figures(self, step, figures):
        self.dict_to_figure(f"{self.model_name}_EvalFigures", figures)

    def eval_audios(self, step, audios, sample_rate):
        self.dict_to_audios(f"{self.model_name}_EvalAudios", audios, sample_rate)

    def test_audios(self, step, audios, sample_rate):
        self.dict_to_audios(f"{self.model_name}_TestAudios", audios, sample_rate)

    def test_figures(self, step, figures):
        self.dict_to_figure(f"{self.model_name}_TestFigures", figures)

    def add_text(self, title, text, step):
        self.log_dict[title] = wandb.HTML(f"<p> {text} </p>")

    def flush(self):
        if self.run:
            wandb.log(self.log_dict)
        # Reset the buffer whether or not a run is active.
        self.log_dict = {}

    def finish(self):
        if self.run:
            self.run.finish()

    def log_artifact(self, file_or_dir, name, artifact_type, aliases=None):
        """Upload a file or directory to W&B as an artifact named ``<run_id>_<name>``."""
        if not self.run:
            return
        name = "_".join([self.run.id, name])
        artifact = wandb.Artifact(name, type=artifact_type)
        data_path = Path(file_or_dir)
        if data_path.is_dir():
            artifact.add_dir(str(data_path))
        elif data_path.is_file():
            artifact.add_file(str(data_path))

        self.run.log_artifact(artifact, aliases=aliases)
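

# Minimal usage sketch (illustrative, not part of the original module). It assumes a
# configured W&B account/API key, and the project name and "tacotron2" model value are
# hypothetical. The ``config`` dict must carry a ``model`` key, since ``__init__`` reads
# ``run.config.model``.
if __name__ == "__main__":
    import numpy as np

    logger = WandbLogger(project="tts-demo", config={"model": "tacotron2"})
    # Buffer a few entries, then push them to the run in one call.
    logger.train_step_stats(step=1, stats={"loss": 0.42, "lr": 1e-4})
    logger.train_audios(step=1, audios={"sample": np.zeros(22050, dtype="float32")}, sample_rate=22050)
    logger.flush()
    logger.finish()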