From 6b1de26869251a3f0af0d22f9f1fe23d0556cc76 Mon Sep 17 00:00:00 2001 From: erogol Date: Wed, 3 Jun 2020 12:16:08 +0200 Subject: [PATCH] correct loss normalization and function refactoring --- vocoder/layers/losses.py | 141 +++++++++++++++++++++------------------ 1 file changed, 77 insertions(+), 64 deletions(-) diff --git a/vocoder/layers/losses.py b/vocoder/layers/losses.py index 1c3d4442..fb4e85d4 100644 --- a/vocoder/layers/losses.py +++ b/vocoder/layers/losses.py @@ -6,6 +6,7 @@ from torch.nn import functional as F class TorchSTFT(): def __init__(self, n_fft, hop_length, win_length, window='hann_window'): + """ Torch based STFT operation """ self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length @@ -33,6 +34,7 @@ class TorchSTFT(): class STFTLoss(nn.Module): + """ Single scale STFT Loss """ def __init__(self, n_fft, hop_length, win_length): super(STFTLoss, self).__init__() self.n_fft = n_fft @@ -50,6 +52,7 @@ class STFTLoss(nn.Module): return loss_mag, loss_sc class MultiScaleSTFTLoss(torch.nn.Module): + """ Multi scale STFT loss """ def __init__(self, n_ffts=(1024, 2048, 512), hop_lengths=(120, 240, 50), @@ -73,6 +76,7 @@ class MultiScaleSTFTLoss(torch.nn.Module): class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss): + """ Multiscale STFT loss for multi band model outputs """ def forward(self, y_hat, y): y_hat = y_hat.view(-1, 1, y_hat.shape[2]) y = y.view(-1, 1, y.shape[2]) @@ -121,16 +125,62 @@ class MelganFeatureLoss(nn.Module): loss_feats = 0 for fake_feat, real_feat in zip(fake_feats, real_feats): loss_feats += torch.mean(torch.abs(fake_feat - real_feat)) + loss_feats /= len(fake_feats) + len(real_feats) return loss_feats -################################## +##################################### # LOSS WRAPPERS +##################################### + + +def _apply_G_adv_loss(scores_fake, loss_func): + """ Compute G adversarial loss function + and normalize values """ + adv_loss = 0 + if isinstance(scores_fake, list): + for score_fake in scores_fake: + fake_loss = loss_func(score_fake) + adv_loss += fake_loss + adv_loss /= len(scores_fake) + else: + fake_loss = loss_func(scores_fake) + adv_loss = fake_loss + return adv_loss + + +def _apply_D_loss(scores_fake, scores_real, loss_func): + """ Compute D loss func and normalize loss values """ + loss = 0 + real_loss = 0 + fake_loss = 0 + if isinstance(scores_fake, list): + # multi-scale loss + for score_fake, score_real in zip(scores_fake, scores_real): + total_loss, real_loss, fake_loss = loss_func(score_fake, score_real) + loss += total_loss + real_loss += real_loss + fake_loss += fake_loss + # normalize loss values with number of scales + loss /= len(scores_fake) + real_loss /= len(scores_real) + fake_loss /= len(scores_fake) + else: + # single scale loss + total_loss, real_loss, fake_loss = loss_func(scores_fake, scores_real) + loss = total_loss + return loss, real_loss, fake_loss + + +################################## +# MODEL LOSSES ################################## class GeneratorLoss(nn.Module): def __init__(self, C): + """ Compute Generator Loss values depending on training + configuration """ super(GeneratorLoss, self).__init__() assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\ " [!] Cannot use HingeGANLoss and MSEGANLoss together." @@ -159,7 +209,8 @@ class GeneratorLoss(nn.Module): self.feat_match_loss = MelganFeatureLoss() def forward(self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None): - loss = 0 + gen_loss = 0 + adv_loss = 0 return_dict = {} # STFT Loss @@ -167,50 +218,41 @@ class GeneratorLoss(nn.Module): stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat.squeeze(1), y.squeeze(1)) return_dict['G_stft_loss_mg'] = stft_loss_mg return_dict['G_stft_loss_sc'] = stft_loss_sc - loss += self.stft_loss_weight * (stft_loss_mg + stft_loss_sc) + gen_loss += self.stft_loss_weight * (stft_loss_mg + stft_loss_sc) # subband STFT Loss if self.use_subband_stft_loss: subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub) return_dict['G_subband_stft_loss_mg'] = subband_stft_loss_mg return_dict['G_subband_stft_loss_sc'] = subband_stft_loss_sc - loss += self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc) + gen_loss += self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc) - # Fake Losses + # multiscale MSE adversarial loss if self.use_mse_gan_loss and scores_fake is not None: - mse_fake_loss = 0 - if isinstance(scores_fake, list): - for score_fake in scores_fake: - fake_loss = self.mse_loss(score_fake) - mse_fake_loss += fake_loss - else: - fake_loss = self.mse_loss(scores_fake) - mse_fake_loss = fake_loss + mse_fake_loss = _apply_G_adv_loss(scores_fake, self.mse_loss) return_dict['G_mse_fake_loss'] = mse_fake_loss - loss += self.mse_gan_loss_weight * mse_fake_loss + adv_loss += self.mse_gan_loss_weight * mse_fake_loss + # multiscale Hinge adversarial loss if self.use_hinge_gan_loss and not scores_fake is not None: - hinge_fake_loss = 0 - if isinstance(scores_fake, list): - for score_fake in scores_fake: - fake_loss = self.hinge_loss(score_fake) - hinge_fake_loss += fake_loss - else: - fake_loss = self.hinge_loss(scores_fake) - hinge_fake_loss = fake_loss + hinge_fake_loss = _apply_G_adv_loss(scores_fake, self.hinge_loss) return_dict['G_hinge_fake_loss'] = hinge_fake_loss - loss += self.hinge_gan_loss_weight * hinge_fake_loss + adv_loss += self.hinge_gan_loss_weight * hinge_fake_loss # Feature Matching Loss if self.use_feat_match_loss and not feats_fake: feat_match_loss = self.feat_match_loss(feats_fake, feats_real) return_dict['G_feat_match_loss'] = feat_match_loss - loss += self.feat_match_loss_weight * feat_match_loss - return_dict['G_loss'] = loss + adv_loss += self.feat_match_loss_weight * feat_match_loss + return_dict['G_loss'] = gen_loss + adv_loss + return_dict['G_gen_loss'] = gen_loss + return_dict['G_adv_loss'] = adv_loss return return_dict class DiscriminatorLoss(nn.Module): + """ Compute Discriminator Loss values depending on training + configuration """ def __init__(self, C): super(DiscriminatorLoss, self).__init__() assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\ @@ -219,9 +261,6 @@ class DiscriminatorLoss(nn.Module): self.use_mse_gan_loss = C.use_mse_gan_loss self.use_hinge_gan_loss = C.use_hinge_gan_loss - self.mse_gan_loss_weight = C.mse_gan_loss_weight - self.hinge_gan_loss_weight = C.hinge_gan_loss_weight - if C.use_mse_gan_loss: self.mse_loss = MSEDLoss() if C.use_hinge_gan_loss: @@ -232,44 +271,18 @@ class DiscriminatorLoss(nn.Module): return_dict = {} if self.use_mse_gan_loss: - mse_gan_loss = 0 - mse_gan_real_loss = 0 - mse_gan_fake_loss = 0 - if isinstance(scores_fake, list): - for score_fake, score_real in zip(scores_fake, scores_real): - total_loss, real_loss, fake_loss = self.mse_loss(score_fake, score_real) - mse_gan_loss += total_loss - mse_gan_real_loss += real_loss - mse_gan_fake_loss += fake_loss - else: - total_loss, real_loss, fake_loss = self.mse_loss(scores_fake, scores_real) - mse_gan_loss = total_loss - mse_gan_real_loss = real_loss - mse_gan_fake_loss = fake_loss - return_dict['D_mse_gan_loss'] = mse_gan_loss - return_dict['D_mse_gan_real_loss'] = mse_gan_real_loss - return_dict['D_mse_gan_fake_loss'] = mse_gan_fake_loss - loss += self.mse_gan_loss_weight * mse_gan_loss + mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss(scores_fake, scores_real, self.mse_loss) + return_dict['D_mse_gan_loss'] = mse_D_loss + return_dict['D_mse_gan_real_loss'] = mse_D_real_loss + return_dict['D_mse_gan_fake_loss'] = mse_D_fake_loss + loss += mse_D_loss if self.use_hinge_gan_loss: - hinge_gan_loss = 0 - hinge_gan_real_loss = 0 - hinge_gan_fake_loss = 0 - if isinstance(scores_fake, list): - for score_fake, score_real in zip(scores_fake, scores_real): - total_loss, real_loss, fake_loss = self.hinge_loss(score_fake, score_real) - hinge_gan_loss += total_loss - hinge_gan_real_loss += real_loss - hinge_gan_fake_loss += fake_loss - else: - total_loss, real_loss, fake_loss = self.hinge_loss(scores_fake, scores_real) - hinge_gan_loss = total_loss - hinge_gan_real_loss = real_loss - hinge_gan_fake_loss = fake_loss - return_dict['D_hinge_gan_loss'] = hinge_gan_loss - return_dict['D_hinge_gan_real_loss'] = hinge_gan_real_loss - return_dict['D_hinge_gan_fake_loss'] = hinge_gan_fake_loss - loss += self.hinge_gan_loss_weight * hinge_gan_loss + hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss(scores_fake, scores_real, self.hinge_loss) + return_dict['D_hinge_gan_loss'] = hinge_D_loss + return_dict['D_hinge_gan_real_loss'] = hinge_D_real_loss + return_dict['D_hinge_gan_fake_loss'] = hinge_D_fake_loss + loss += hinge_D_loss return_dict['D_loss'] = loss return return_dict \ No newline at end of file