bug fix and run desc in tensorboard

This commit is contained in:
erogol 2020-03-10 22:38:51 +01:00
parent 3472a41255
commit 2a15e39166
3 changed files with 10 additions and 4 deletions

View File

@ -1,7 +1,7 @@
{ {
"model": "Tacotron2", // one of the model in models/ "model": "Tacotron2", // one of the model in models/
"run_name": "ljspeech-stft_params", "run_name": "ljspeech",
"run_description": "tacotron2 cosntant stf parameters", "run_description": "tacotron2 with guided attention and -1 1 normalization and no preemphasis",
// AUDIO PARAMETERS // AUDIO PARAMETERS
"audio":{ "audio":{

View File

@ -47,7 +47,7 @@ def setup_loader(ap, r, is_val=False, verbose=False):
dataset = MyDataset( dataset = MyDataset(
r, r,
c.text_cleaner, c.text_cleaner,
compute_linear_spec=True if c.model.lower() is 'tacotron' else False compute_linear_spec=True if c.model.lower() is 'tacotron' else False,
meta_data=meta_data_eval if is_val else meta_data_train, meta_data=meta_data_eval if is_val else meta_data_train,
ap=ap, ap=ap,
tp=c.characters if 'characters' in c.keys() else None, tp=c.characters if 'characters' in c.keys() else None,
@ -410,7 +410,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
loss_dict['ga_loss'].item(), loss_dict['ga_loss'].item(),
keep_avg['avg_ga_loss'], keep_avg['avg_ga_loss'],
align_score, keep_avg['avg_align_score']), align_score, keep_avg['avg_align_score']),
flush=Tr ue) flush=True)
if args.rank == 0: if args.rank == 0:
# Diagnostic visualizations # Diagnostic visualizations
@ -696,6 +696,9 @@ if __name__ == '__main__':
LOG_DIR = OUT_PATH LOG_DIR = OUT_PATH
tb_logger = Logger(LOG_DIR) tb_logger = Logger(LOG_DIR)
# write model desc to tensorboard
tb_logger.tb_add_text('model-description', c['run_description'], 0)
try: try:
main(args) main(args)
except KeyboardInterrupt: except KeyboardInterrupt:

View File

@ -75,3 +75,6 @@ class Logger(object):
def tb_test_figures(self, step, figures): def tb_test_figures(self, step, figures):
self.dict_to_tb_figure("TestFigures", figures, step) self.dict_to_tb_figure("TestFigures", figures, step)
def tb_add_text(self, title, text, step):
self.writer.add_text(title, text, step)