From 062e8a0880e895816a66c775aff424c053c9be63 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Thu, 13 Dec 2018 18:18:37 +0100
Subject: [PATCH] logger for tensorboard plotting

---
 config.json      |   1 +
 requirements.txt |   3 +-
 train.py         | 122 ++++++++++++++++------------------------------
 utils/logger.py  |  75 +++++++++++++++++++++++++++++
 4 files changed, 118 insertions(+), 83 deletions(-)
 create mode 100644 utils/logger.py

diff --git a/config.json b/config.json
index e67d9a4f..bd49f6e8 100644
--- a/config.json
+++ b/config.json
@@ -40,6 +40,7 @@
     "checkpoint": true,
     "save_step": 5000,
     "print_step": 10,
+    "tb_model_param_stats": true,     // if true, plot per-layer parameter stats on tensorboard. Memory intensive, but useful for debugging.
     "run_eval": true,
     "data_path": "../../Data/LJSpeech-1.1/",    // can overwritten from command argument
diff --git a/requirements.txt b/requirements.txt
index 73e5dae7..e49445d5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,5 +8,4 @@ tensorboardX
 matplotlib==2.0.2
 Pillow
 flask
-scipy==0.19.0
-lws
\ No newline at end of file
+scipy==0.19.0
\ No newline at end of file
diff --git a/train.py b/train.py
index 8fe07ded..3d4212bc 100644
--- a/train.py
+++ b/train.py
@@ -22,6 +22,7 @@ from models.tacotron import Tacotron
 from layers.losses import L1LossMasked
 from utils.audio import AudioProcessor
 from utils.synthesis import synthesis
+from utils.logger import Logger
 
 torch.manual_seed(1)
 use_cuda = torch.cuda.is_available()
@@ -169,15 +170,13 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st,
         avg_step_time += step_time
 
         # Plot Training Iter Stats
-        tb.add_scalar('TrainIterLoss/TotalLoss', loss.item(), current_step)
-        tb.add_scalar('TrainIterLoss/LinearLoss', linear_loss.item(),
-                      current_step)
-        tb.add_scalar('TrainIterLoss/MelLoss', mel_loss.item(), current_step)
-        tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
-                      current_step)
-        tb.add_scalar('Params/GradNorm', grad_norm, current_step)
-        tb.add_scalar('Params/GradNormSt', grad_norm_st, current_step)
-        tb.add_scalar('Time/StepTime', step_time, current_step)
+        iter_stats = {"loss_postnet": linear_loss.item(),
+                      "loss_decoder": mel_loss.item(),
+                      "lr": current_lr,
+                      "grad_norm": grad_norm,
+                      "grad_norm_st": grad_norm_st,
+                      "step_time": step_time}
+        tb_logger.tb_train_iter_stats(current_step, iter_stats)
 
         if current_step % c.save_step == 0:
             if c.checkpoint:
@@ -189,28 +188,17 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st,
             # Diagnostic visualizations
             const_spec = linear_output[0].data.cpu().numpy()
             gt_spec = linear_input[0].data.cpu().numpy()
-
-            const_spec = plot_spectrogram(const_spec, ap)
-            gt_spec = plot_spectrogram(gt_spec, ap)
-            tb.add_figure('Visual/Reconstruction', const_spec, current_step)
-            tb.add_figure('Visual/GroundTruth', gt_spec, current_step)
-
             align_img = alignments[0].data.cpu().numpy()
-            align_img = plot_alignment(align_img)
-            tb.add_figure('Visual/Alignment', align_img, current_step)
+
+            figures = {"prediction": plot_spectrogram(const_spec, ap),
+                       "ground_truth": plot_spectrogram(gt_spec, ap),
+                       "alignment": plot_alignment(align_img)}
+            tb_logger.tb_train_figures(current_step, figures)
 
             # Sample audio
-            audio_signal = linear_output[0].data.cpu().numpy()
-            ap.griffin_lim_iters = 60
-            audio_signal = ap.inv_spectrogram(audio_signal.T)
-            try:
-                tb.add_audio(
-                    'SampleAudio',
-                    audio_signal,
-                    current_step,
-                    sample_rate=c.sample_rate)
-            except:
-                pass
+            tb_logger.tb_train_audios(current_step,
+                                      {'TrainAudio': ap.inv_spectrogram(const_spec.T)},
+                                      c.sample_rate)
 
     avg_linear_loss /= (num_iter + 1)
     avg_mel_loss /= (num_iter + 1)
@@ -229,12 +217,13 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st,
         flush=True)
 
     # Plot Training Epoch Stats
-    tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
-    tb.add_scalar('TrainEpochLoss/LinearLoss', avg_linear_loss, current_step)
-    tb.add_scalar('TrainEpochLoss/MelLoss', avg_mel_loss, current_step)
-    tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step)
-    tb.add_scalar('Time/EpochTime', epoch_time, epoch)
-    epoch_time = 0
+    epoch_stats = {"loss_postnet": avg_linear_loss,
+                   "loss_decoder": avg_mel_loss,
+                   "stop_loss": avg_stop_loss,
+                   "epoch_time": epoch_time}
+    tb_logger.tb_train_epoch_stats(current_step, epoch_stats)
+    if c.tb_model_param_stats:
+        tb_logger.tb_model_weights(model, current_step)
 
     return avg_linear_loss, current_step
@@ -316,74 +305,45 @@ def evaluate(model, criterion, criterion_st, ap, current_step):
         gt_spec = linear_input[idx].data.cpu().numpy()
         align_img = alignments[idx].data.cpu().numpy()
 
-        const_spec = plot_spectrogram(const_spec, ap)
-        gt_spec = plot_spectrogram(gt_spec, ap)
-        align_img = plot_alignment(align_img)
-
-        tb.add_figure('ValVisual/Reconstruction', const_spec, current_step)
-        tb.add_figure('ValVisual/GroundTruth', gt_spec, current_step)
-        tb.add_figure('ValVisual/ValidationAlignment', align_img,
-                      current_step)
+        eval_figures = {"prediction": plot_spectrogram(const_spec, ap),
+                        "ground_truth": plot_spectrogram(gt_spec, ap),
+                        "alignment": plot_alignment(align_img)}
+        tb_logger.tb_eval_figures(current_step, eval_figures)
 
         # Sample audio
-        audio_signal = linear_output[idx].data.cpu().numpy()
-        ap.griffin_lim_iters = 60
-        audio_signal = ap.inv_spectrogram(audio_signal.T)
-        try:
-            tb.add_audio(
-                'ValSampleAudio',
-                audio_signal,
-                current_step,
-                sample_rate=c.audio["sample_rate"])
-        except:
-            # sometimes audio signal is out of boundaries
-            pass
+        tb_logger.tb_eval_audios(current_step, {"ValAudio": ap.inv_spectrogram(const_spec.T)}, c.audio["sample_rate"])
 
     # compute average losses
     avg_linear_loss /= (num_iter + 1)
     avg_mel_loss /= (num_iter + 1)
     avg_stop_loss /= (num_iter + 1)
-    avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss
 
-    # Plot Learning Stats
-    tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss,
-                  current_step)
-    tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss,
-                  current_step)
-    tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step)
-    tb.add_scalar('ValEpochLoss/Stop_loss', avg_stop_loss,
-                  current_step)
+    # Plot Validation Stats
+    epoch_stats = {"loss_postnet": avg_linear_loss,
+                   "loss_decoder": avg_mel_loss,
+                   "stop_loss": avg_stop_loss}
+    tb_logger.tb_eval_stats(current_step, epoch_stats)
 
     # test sentences
-    ap.griffin_lim_iters = 60
+    test_audios = {}
+    test_figures = {}
     for idx, test_sentence in enumerate(test_sentences):
         try:
             wav, alignment, linear_spec, _, stop_tokens = synthesis(
                 model, test_sentence, c, use_cuda, ap)
-
             file_path = os.path.join(AUDIO_PATH, str(current_step))
             os.makedirs(file_path, exist_ok=True)
             file_path = os.path.join(file_path,
                                      "TestSentence_{}.wav".format(idx))
             ap.save_wav(wav, file_path)
-
-            wav_name = 'TestSentences/{}'.format(idx)
-            tb.add_audio(
-                wav_name,
-                wav,
-                current_step,
-                sample_rate=c.audio['sample_rate'])
-
-            linear_spec = plot_spectrogram(linear_spec, ap)
-            align_img = plot_alignment(alignment)
-            tb.add_figure('TestSentences/{}_Spectrogram'.format(idx),
-                          linear_spec, current_step)
-            tb.add_figure('TestSentences/{}_Alignment'.format(idx), align_img,
-                          current_step)
+            test_audios['{}-audio'.format(idx)] = wav
+            test_figures['{}-prediction'.format(idx)] = plot_spectrogram(linear_spec, ap)
+            test_figures['{}-alignment'.format(idx)] = plot_alignment(alignment)
         except:
             print(" !! Error creating Test Sentence -", idx)
             traceback.print_exc()
-            pass
+    tb_logger.tb_test_audios(current_step, test_audios, c.audio['sample_rate'])
+    tb_logger.tb_test_figures(current_step, test_figures)
 
     return avg_linear_loss
@@ -496,7 +456,7 @@ if __name__ == '__main__':
 
     # setup tensorboard
     LOG_DIR = OUT_PATH
-    tb = SummaryWriter(LOG_DIR)
+    tb_logger = Logger(LOG_DIR)
 
     # Conditional imports
     preprocessor = importlib.import_module('datasets.preprocess')
diff --git a/utils/logger.py b/utils/logger.py
new file mode 100644
index 00000000..c8cfcf28
--- /dev/null
+++ b/utils/logger.py
@@ -0,0 +1,75 @@
+import traceback
+from tensorboardX import SummaryWriter
+
+
+class Logger(object):
+    def __init__(self, log_dir):
+        self.writer = SummaryWriter(log_dir)
+        self.train_stats = {}
+        self.eval_stats = {}
+
+    def tb_model_weights(self, model, step):
+        layer_num = 1  # counts parameters, not layers; only used in tag names
+        for name, param in model.named_parameters():
+            self.writer.add_scalar(
+                "layer{}-ModelParams/{}/max".format(layer_num, name),
+                param.max(), step)
+            self.writer.add_scalar(
+                "layer{}-ModelParams/{}/min".format(layer_num, name),
+                param.min(), step)
+            self.writer.add_scalar(
+                "layer{}-ModelParams/{}/mean".format(layer_num, name),
+                param.mean(), step)
+            self.writer.add_scalar(
+                "layer{}-ModelParams/{}/std".format(layer_num, name),
+                param.std(), step)
+            self.writer.add_histogram(
+                "layer{}-{}/param".format(layer_num, name), param, step)
+            self.writer.add_histogram(
+                "layer{}-{}/grad".format(layer_num, name), param.grad, step)  # assumes backward() has populated grads
+            layer_num += 1
+
+    def dict_to_tb_scalar(self, scope_name, stats, step):
+        for key, value in stats.items():
+            self.writer.add_scalar('{}/{}'.format(scope_name, key), value, step)
+
+    def dict_to_tb_figure(self, scope_name, figures, step):
+        for key, value in figures.items():
+            self.writer.add_figure('{}/{}'.format(scope_name, key), value, step)
+
+    def dict_to_tb_audios(self, scope_name, audios, step, sample_rate):
+        for key, value in audios.items():
+            try:
+                self.writer.add_audio('{}/{}'.format(scope_name, key), value, step, sample_rate=sample_rate)
+            except Exception:  # e.g. the audio signal is out of boundaries
+                traceback.print_exc()
+
+    def tb_train_iter_stats(self, step, stats):
+        self.dict_to_tb_scalar("TrainIterStats", stats, step)
+
+    def tb_train_epoch_stats(self, step, stats):
+        self.dict_to_tb_scalar("TrainEpochStats", stats, step)
+
+    def tb_train_figures(self, step, figures):
+        self.dict_to_tb_figure("TrainFigures", figures, step)
+
+    def tb_train_audios(self, step, audios, sample_rate):
+        self.dict_to_tb_audios("TrainAudios", audios, step, sample_rate)
+
+    def tb_eval_stats(self, step, stats):
+        self.dict_to_tb_scalar("EvalStats", stats, step)
+
+    def tb_eval_figures(self, step, figures):
+        self.dict_to_tb_figure("EvalFigures", figures, step)
+
+    def tb_eval_audios(self, step, audios, sample_rate):
+        self.dict_to_tb_audios("EvalAudios", audios, step, sample_rate)
+
+    def tb_test_audios(self, step, audios, sample_rate):
+        self.dict_to_tb_audios("TestAudios", audios, step, sample_rate)
+
+    def tb_test_figures(self, step, figures):
+        self.dict_to_tb_figure("TestFigures", figures, step)
+
+
\ No newline at end of file
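
Note for reviewers: a minimal, hypothetical smoke test of the new Logger API
(not part of the patch; assumes tensorboardX and torch are installed, and the
dict keys simply mirror the ones train.py uses):

    import torch
    from utils.logger import Logger

    # Tiny stand-in model so tb_model_weights has parameters to log.
    model = torch.nn.Linear(4, 2)
    model(torch.randn(8, 4)).sum().backward()  # populate param.grad for the grad histograms

    tb_logger = Logger("logs/demo")  # wraps tensorboardX.SummaryWriter
    step = 10
    tb_logger.tb_train_iter_stats(step, {"loss_postnet": 0.42,
                                         "loss_decoder": 0.37,
                                         "lr": 1e-3,
                                         "grad_norm": 0.9,
                                         "grad_norm_st": 0.1,
                                         "step_time": 0.25})
    tb_logger.tb_model_weights(model, step)  # per-parameter max/min/mean/std + histograms

Each tb_* method writes under its own scope (e.g. TrainIterStats/loss_postnet),
so train, eval, and test events stay grouped separately in the TensorBoard UI.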