From 604453bcb770149a54a0deabcc00baa888a5dd02 Mon Sep 17 00:00:00 2001 From: erogol Date: Sun, 31 May 2020 18:56:57 +0200 Subject: [PATCH] add multiband melgan config and rename loss weights for vocoder --- vocoder/configs/melgan_config.json | 8 ++++---- vocoder/layers/losses.py | 24 ++++++++++++------------ 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/vocoder/configs/melgan_config.json b/vocoder/configs/melgan_config.json index 7da42f75..b97206fa 100644 --- a/vocoder/configs/melgan_config.json +++ b/vocoder/configs/melgan_config.json @@ -53,10 +53,10 @@ "use_hinge_gan_loss": false, "use_feat_match_loss": false, // use only with melgan discriminators - "stft_loss_alpha": 1, - "mse_gan_loss_alpha": 1, - "hinge_gan_loss_alpha": 1, - "feat_match_loss_alpha": 10.0, + "stft_loss_weight": 1, + "mse_gan_loss_weight": 2.5, + "hinge_gan_loss_weight": 2.5, + "feat_match_loss_weight": 10.0, "stft_loss_params": { "n_ffts": [1024, 2048, 512], diff --git a/vocoder/layers/losses.py b/vocoder/layers/losses.py index 11985629..27091c0a 100644 --- a/vocoder/layers/losses.py +++ b/vocoder/layers/losses.py @@ -149,10 +149,10 @@ class GeneratorLoss(nn.Module): self.use_hinge_gan_loss = C.use_hinge_gan_loss self.use_feat_match_loss = C.use_feat_match_loss - self.stft_loss_alpha = C.stft_loss_alpha - self.mse_gan_loss_alpha = C.mse_gan_loss_alpha - self.hinge_gan_loss_alpha = C.hinge_gan_loss_alpha - self.feat_match_loss_alpha = C.feat_match_loss_alpha + self.stft_loss_weight = C.stft_loss_weight + self.mse_gan_loss_weight = C.mse_gan_loss_weight + self.hinge_gan_loss_weight = C.hinge_gan_loss_weight + self.feat_match_loss_weight = C.feat_match_loss_weight if C.use_stft_loss: self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params) @@ -172,7 +172,7 @@ class GeneratorLoss(nn.Module): stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat.squeeze(1), y.squeeze(1)) return_dict['G_stft_loss_mg'] = stft_loss_mg return_dict['G_stft_loss_sc'] = stft_loss_sc - loss += self.stft_loss_alpha * (stft_loss_mg + stft_loss_sc) + loss += self.stft_loss_weight * (stft_loss_mg + stft_loss_sc) # Fake Losses if self.use_mse_gan_loss and scores_fake is not None: @@ -185,7 +185,7 @@ class GeneratorLoss(nn.Module): fake_loss = self.mse_loss(scores_fake) mse_fake_loss = fake_loss return_dict['G_mse_fake_loss'] = mse_fake_loss - loss += self.mse_gan_loss_alpha * mse_fake_loss + loss += self.mse_gan_loss_weight * mse_fake_loss if self.use_hinge_gan_loss and not scores_fake is not None: hinge_fake_loss = 0 @@ -197,13 +197,13 @@ class GeneratorLoss(nn.Module): fake_loss = self.hinge_loss(scores_fake) hinge_fake_loss = fake_loss return_dict['G_hinge_fake_loss'] = hinge_fake_loss - loss += self.hinge_gan_loss_alpha * hinge_fake_loss + loss += self.hinge_gan_loss_weight * hinge_fake_loss # Feature Matching Loss if self.use_feat_match_loss and not feats_fake: feat_match_loss = self.feat_match_loss(feats_fake, feats_real) return_dict['G_feat_match_loss'] = feat_match_loss - loss += self.feat_match_loss_alpha * feat_match_loss + loss += self.feat_match_loss_weight * feat_match_loss return_dict['G_loss'] = loss return return_dict @@ -217,8 +217,8 @@ class DiscriminatorLoss(nn.Module): self.use_mse_gan_loss = C.use_mse_gan_loss self.use_hinge_gan_loss = C.use_hinge_gan_loss - self.mse_gan_loss_alpha = C.mse_gan_loss_alpha - self.hinge_gan_loss_alpha = C.hinge_gan_loss_alpha + self.mse_gan_loss_weight = C.mse_gan_loss_weight + self.hinge_gan_loss_weight = C.hinge_gan_loss_weight if C.use_mse_gan_loss: self.mse_loss = MSEDLoss() @@ -247,7 +247,7 @@ class DiscriminatorLoss(nn.Module): return_dict['D_mse_gan_loss'] = mse_gan_loss return_dict['D_mse_gan_real_loss'] = mse_gan_real_loss return_dict['D_mse_gan_fake_loss'] = mse_gan_fake_loss - loss += self.mse_gan_loss_alpha * mse_gan_loss + loss += self.mse_gan_loss_weight * mse_gan_loss if self.use_hinge_gan_loss: hinge_gan_loss = 0 @@ -267,7 +267,7 @@ class DiscriminatorLoss(nn.Module): return_dict['D_hinge_gan_loss'] = hinge_gan_loss return_dict['D_hinge_gan_real_loss'] = hinge_gan_real_loss return_dict['D_hinge_gan_fake_loss'] = hinge_gan_fake_loss - loss += self.hinge_gan_loss_alpha * hinge_gan_loss + loss += self.hinge_gan_loss_weight * hinge_gan_loss return_dict['D_loss'] = loss return return_dict \ No newline at end of file