From 2b361f0b50daed9e4555d7c09f371ba201564ed7 Mon Sep 17 00:00:00 2001 From: erogol Date: Thu, 4 Jun 2020 09:51:20 +0200 Subject: [PATCH] visualization fix --- vocoder/train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vocoder/train.py b/vocoder/train.py index 03e14f4a..a75c6a12 100644 --- a/vocoder/train.py +++ b/vocoder/train.py @@ -120,11 +120,13 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, y_hat = model_G(c_G) y_hat_sub = None y_G_sub = None + y_hat_vis = y_hat # for visualization # PQMF formatting if y_hat.shape[1] > 1: y_hat_sub = y_hat y_hat = model_G.pqmf_synthesis(y_hat) + y_hat_vis = y_hat y_G_sub = model_G.pqmf_analysis(y_G) if global_step > c.steps_to_start_discriminator: @@ -265,12 +267,12 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, model_losses=loss_dict) # compute spectrograms - figures = plot_results(y_hat, y_G, ap, global_step, + figures = plot_results(y_hat_vis, y_G, ap, global_step, 'train') tb_logger.tb_train_figures(global_step, figures) # Sample audio - sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() + sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy() tb_logger.tb_train_audios(global_step, {'train/audio': sample_voice}, c.audio["sample_rate"])