From 4619b84ddbda458b202e74f28db9a02191c0a499 Mon Sep 17 00:00:00 2001
From: erogol
Date: Tue, 2 Jun 2020 17:07:15 +0200
Subject: [PATCH] multiscale subband stft loss and missing tb_logger for eval stats

---
 vocoder/configs/melgan_config.json |  8 +++++
 vocoder/layers/losses.py           | 21 ++++++++++--
 vocoder/train.py                   | 53 ++++++++++++++----------
 3 files changed, 51 insertions(+), 31 deletions(-)

diff --git a/vocoder/configs/melgan_config.json b/vocoder/configs/melgan_config.json
index b97206fa..4aeba82e 100644
--- a/vocoder/configs/melgan_config.json
+++ b/vocoder/configs/melgan_config.json
@@ -49,11 +49,13 @@
 
     // LOSS PARAMETERS
     "use_stft_loss": true,
+    "use_subband_stft_loss": true,
     "use_mse_gan_loss": true,
     "use_hinge_gan_loss": false,
     "use_feat_match_loss": false,  // use only with melgan discriminators
 
     "stft_loss_weight": 1,
+    "subband_stft_loss_weight": 1,
     "mse_gan_loss_weight": 2.5,
     "hinge_gan_loss_weight": 2.5,
     "feat_match_loss_weight": 10.0,
@@ -63,6 +65,12 @@
         "hop_lengths": [120, 240, 50],
         "win_lengths": [600, 1200, 240]
     },
+    "subband_stft_loss_params":{
+        "n_ffts": [384, 683, 171],
+        "hop_lengths": [30, 60, 10],
+        "win_lengths": [150, 300, 60]
+    },
+
     "target_loss": "avg_G_loss",  // loss value to pick the best model
 
     // DISCRIMINATOR
diff --git a/vocoder/layers/losses.py b/vocoder/layers/losses.py
index 27091c0a..e242d209 100644
--- a/vocoder/layers/losses.py
+++ b/vocoder/layers/losses.py
@@ -49,7 +49,6 @@ class STFTLoss(nn.Module):
         loss_sc = torch.norm(y_M - y_hat_M, p="fro") / torch.norm(y_M, p="fro")
         return loss_mag, loss_sc
 
-
 class MultiScaleSTFTLoss(torch.nn.Module):
     def __init__(self,
                  n_ffts=[1024, 2048, 512],
@@ -73,6 +72,13 @@ class MultiScaleSTFTLoss(torch.nn.Module):
         return loss_mag, loss_sc
 
 
+class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
+    def forward(self, y_hat, y):
+        y_hat = y_hat.view(-1, 1, y_hat.shape[2])
+        y = y.view(-1, 1, y.shape[2])
+        return super().forward(y_hat.squeeze(1), y.squeeze(1))
+
+
 class MSEGLoss(nn.Module):
     """ Mean Squared Generator Loss """
     def __init__(self,):
@@ -145,17 +151,21 @@ class GeneratorLoss(nn.Module):
             " [!] Cannot use HingeGANLoss and MSEGANLoss together."
 
         self.use_stft_loss = C.use_stft_loss
+        self.use_subband_stft_loss = C.use_subband_stft_loss
         self.use_mse_gan_loss = C.use_mse_gan_loss
         self.use_hinge_gan_loss = C.use_hinge_gan_loss
         self.use_feat_match_loss = C.use_feat_match_loss
 
         self.stft_loss_weight = C.stft_loss_weight
+        self.subband_stft_loss_weight = C.subband_stft_loss_weight
         self.mse_gan_loss_weight = C.mse_gan_loss_weight
         self.hinge_gan_loss_weight = C.hinge_gan_loss_weight
         self.feat_match_loss_weight = C.feat_match_loss_weight
 
         if C.use_stft_loss:
             self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params)
+        if C.use_subband_stft_loss:
+            self.subband_stft_loss = MultiScaleSubbandSTFTLoss(**C.subband_stft_loss_params)
         if C.use_mse_gan_loss:
             self.mse_loss = MSEGLoss()
         if C.use_hinge_gan_loss:
@@ -163,7 +173,7 @@ class GeneratorLoss(nn.Module):
         if C.use_feat_match_loss:
             self.feat_match_loss = MelganFeatureLoss()
 
-    def forward(self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None):
+    def forward(self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None):
         loss = 0
         return_dict = {}
 
@@ -174,6 +184,13 @@ class GeneratorLoss(nn.Module):
             return_dict['G_stft_loss_sc'] = stft_loss_sc
             loss += self.stft_loss_weight * (stft_loss_mg + stft_loss_sc)
 
+        # subband STFT Loss
+        if self.use_subband_stft_loss:
+            subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub)
+            return_dict['G_subband_stft_loss_mg'] = subband_stft_loss_mg
+            return_dict['G_subband_stft_loss_sc'] = subband_stft_loss_sc
+            loss += self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc)
+
         # Fake Losses
         if self.use_mse_gan_loss and scores_fake is not None:
             mse_fake_loss = 0
diff --git a/vocoder/train.py b/vocoder/train.py
index 27e14e8a..6341bb6c 100644
--- a/vocoder/train.py
+++ b/vocoder/train.py
@@ -122,30 +122,27 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
         # generator pass
         optimizer_G.zero_grad()
         y_hat = model_G(c_G)
-
-        in_real_D = y_hat
-        in_fake_D = y_G
+        y_hat_sub = None
+        y_G_sub = None
 
         # PQMF formatting
         if y_hat.shape[1] > 1:
-            in_real_D = y_G
-            in_fake_D = model_G.pqmf_synthesis(y_hat)
-            y_G = model_G.pqmf_analysis(y_G)
-            y_hat = y_hat.view(-1, 1, y_hat.shape[2])
-            y_G = y_G.view(-1, 1, y_G.shape[2])
+            y_hat_sub = y_hat
+            y_hat = model_G.pqmf_synthesis(y_hat)
+            y_G_sub = model_G.pqmf_analysis(y_G)
 
         if global_step > c.steps_to_start_discriminator:
 
            # run D with or without cond. features
            if len(signature(model_D.forward).parameters) == 2:
-                D_out_fake = model_D(in_fake_D, c_G)
+                D_out_fake = model_D(y_hat, c_G)
             else:
-                D_out_fake = model_D(in_fake_D)
+                D_out_fake = model_D(y_hat)
             D_out_real = None
 
             if c.use_feat_match_loss:
                 with torch.no_grad():
-                    D_out_real = model_D(in_real_D)
+                    D_out_real = model_D(y_G)
 
             # format D outputs
             if isinstance(D_out_fake, tuple):
@@ -162,7 +159,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
 
         # compute losses
         loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
-                                  feats_real)
+                                  feats_real, y_hat_sub, y_G_sub)
         loss_G = loss_G_dict['G_loss']
 
         # optimizer generator
@@ -272,12 +269,12 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
                                        model_losses=loss_dict)
 
             # compute spectrograms
-            figures = plot_results(in_fake_D, in_real_D, ap, global_step,
+            figures = plot_results(y_hat, y_G, ap, global_step,
                                    'train')
             tb_logger.tb_train_figures(global_step, figures)
 
             # Sample audio
-            sample_voice = in_fake_D[0].squeeze(0).detach().cpu().numpy()
+            sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
             tb_logger.tb_train_audios(global_step,
                                       {'train/audio': sample_voice},
                                       c.audio["sample_rate"])
@@ -320,23 +317,20 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
 
         # generator pass
         y_hat = model_G(c_G)
-
-        in_real_D = y_hat
-        in_fake_D = y_G
+        y_hat_sub = None
+        y_G_sub = None
 
         # PQMF formatting
         if y_hat.shape[1] > 1:
-            in_real_D = y_G
-            in_fake_D = model_G.pqmf_synthesis(y_hat)
-            y_G = model_G.pqmf_analysis(y_G)
-            y_hat = y_hat.view(-1, 1, y_hat.shape[2])
-            y_G = y_G.view(-1, 1, y_G.shape[2])
+            y_hat_sub = y_hat
+            y_hat = model_G.pqmf_synthesis(y_hat)
+            y_G_sub = model_G.pqmf_analysis(y_G)
 
-        D_out_fake = model_D(in_fake_D)
+        D_out_fake = model_D(y_hat)
         D_out_real = None
         if c.use_feat_match_loss:
             with torch.no_grad():
-                D_out_real = model_D(in_real_D)
+                D_out_real = model_D(y_G)
 
         # format D outputs
         if isinstance(D_out_fake, tuple):
@@ -350,7 +344,7 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
 
         # compute losses
         loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
-                                  feats_real)
+                                  feats_real, y_hat_sub, y_G_sub)
 
         loss_dict = dict()
         for key, value in loss_G_dict.items():
@@ -364,7 +358,7 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
         for key, value in loss_G_dict.items():
             update_eval_values['avg_' + key] = value.item()
         update_eval_values['avg_loader_time'] = loader_time
-        update_eval_values['avg_step_time'] = step_time
+        update_eval_values['avg_step_time'] = step_time
         keep_avg.update_values(update_eval_values)
 
         # print eval stats
@@ -372,18 +366,19 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
             c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
 
     # compute spectrograms
-    figures = plot_results(in_fake_D, in_real_D, ap, global_step, 'eval')
+    figures = plot_results(y_hat, y_G, ap, global_step, 'eval')
     tb_logger.tb_eval_figures(global_step, figures)
 
     # Sample audio
-    sample_voice = in_fake_D[0].squeeze(0).detach().cpu().numpy()
+    sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
     tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
                              c.audio["sample_rate"])
 
     # synthesize a full voice
     data_loader.return_segments = False
-    tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
+    tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
+    return keep_avg.avg_values
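
Note on the new loss: MultiScaleSubbandSTFTLoss folds the PQMF subband axis into the batch axis and then reuses the multi-scale STFT criterion at the smaller resolutions added to melgan_config.json above, since each subband signal is shorter than the full-band waveform. Below is a standalone sketch of that idea; stft_mag and multiscale_stft_loss are simplified stand-ins for the repo's STFTLoss / MultiScaleSTFTLoss (which return the log-magnitude and spectral-convergence terms separately), so treat this as an illustration rather than the repo implementation.

import torch
import torch.nn.functional as F


def stft_mag(x, n_fft, hop_length, win_length):
    """Magnitude spectrogram of a batch of 1D signals with shape [B, T]."""
    window = torch.hann_window(win_length, device=x.device)
    spec = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                      window=window, return_complex=True)
    return spec.abs().clamp(min=1e-7)


def multiscale_stft_loss(y_hat, y, n_ffts=(384, 683, 171),
                         hop_lengths=(30, 60, 10), win_lengths=(150, 300, 60)):
    """Average spectral-convergence + log-magnitude L1 distance over several STFT resolutions."""
    loss = 0.0
    for n_fft, hop, win in zip(n_ffts, hop_lengths, win_lengths):
        y_mag = stft_mag(y, n_fft, hop, win)
        y_hat_mag = stft_mag(y_hat, n_fft, hop, win)
        loss_sc = torch.norm(y_mag - y_hat_mag, p="fro") / torch.norm(y_mag, p="fro")
        loss_mag = F.l1_loss(torch.log(y_hat_mag), torch.log(y_mag))
        loss = loss + loss_sc + loss_mag
    return loss / len(n_ffts)


def subband_stft_loss(y_hat_sub, y_sub, **stft_kwargs):
    """Fold the subband axis into the batch axis: [B, C, T'] -> [B * C, T']."""
    b, c, t = y_hat_sub.shape
    return multiscale_stft_loss(y_hat_sub.reshape(b * c, t),
                                y_sub.reshape(b * c, t), **stft_kwargs)


if __name__ == "__main__":
    y_hat_sub = torch.randn(2, 4, 4000)  # e.g. 4 PQMF subbands, each at a fraction of the full rate
    y_sub = torch.randn(2, 4, 4000)
    print(subband_stft_loss(y_hat_sub, y_sub))

Folding the subbands into the batch dimension is equivalent to the patch's view(-1, 1, T).squeeze(1) reshape, so every subband is scored independently by the same multi-resolution criterion.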
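Shape-level sketch of the reorganized generator step: after this patch the raw multi-band prediction is kept in y_hat_sub for the subband STFT loss, while the discriminator, the full-band STFT loss, and the tensorboard figures/audio all consume the PQMF-synthesized full-band y_hat. The pqmf_synthesis / pqmf_analysis stubs below are placeholders that only reproduce the tensor shapes, not a real PQMF filterbank.

import torch

B, SUBBANDS, T = 2, 4, 8192                       # full-band segment length


def pqmf_synthesis(y_sub):                         # [B, C, T//C] -> [B, 1, T]  (placeholder)
    b, c, t = y_sub.shape
    return y_sub.transpose(1, 2).reshape(b, 1, c * t)


def pqmf_analysis(y):                              # [B, 1, T] -> [B, C, T//C]  (placeholder)
    b, _, t = y.shape
    return y.reshape(b, t // SUBBANDS, SUBBANDS).transpose(1, 2)


y_G = torch.randn(B, 1, T)                         # target waveform segment
y_hat = torch.randn(B, SUBBANDS, T // SUBBANDS)    # multi-band generator output

y_hat_sub, y_G_sub = None, None
if y_hat.shape[1] > 1:                             # PQMF formatting, as in the patched train step
    y_hat_sub = y_hat                              # keep subbands for the subband STFT loss
    y_hat = pqmf_synthesis(y_hat)                  # full-band signal for D and full-band losses
    y_G_sub = pqmf_analysis(y_G)                   # matching subband decomposition of the target

print(y_hat.shape, y_hat_sub.shape, y_G_sub.shape)  # [B, 1, T], [B, C, T//C], [B, C, T//C]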