From acb367be263e40ebbc6771ed7ed3ac37fd4f92ae Mon Sep 17 00:00:00 2001 From: erogol Date: Fri, 5 Jun 2020 13:23:11 +0200 Subject: [PATCH] bug fix vocoder training --- vocoder/layers/losses.py | 4 ++-- vocoder/train.py | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/vocoder/layers/losses.py b/vocoder/layers/losses.py index 22e2fb54..a24b129c 100644 --- a/vocoder/layers/losses.py +++ b/vocoder/layers/losses.py @@ -199,8 +199,8 @@ class GeneratorLoss(nn.Module): self.stft_loss_weight = C.stft_loss_weight self.subband_stft_loss_weight = C.subband_stft_loss_weight - self.mse_gan_loss_weight = C.mse_gan_loss_weight - self.hinge_gan_loss_weight = C.hinge_gan_loss_weight + self.mse_gan_loss_weight = C.mse_G_loss_weight + self.hinge_gan_loss_weight = C.hinge_G_loss_weight self.feat_match_loss_weight = C.feat_match_loss_weight if C.use_stft_loss: diff --git a/vocoder/train.py b/vocoder/train.py index a75c6a12..d1e6befe 100644 --- a/vocoder/train.py +++ b/vocoder/train.py @@ -173,7 +173,10 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, loss_dict = dict() for key, value in loss_G_dict.items(): - loss_dict[key] = value.item() + if isinstance(value, int): + loss_dict[key] = value + else: + loss_dict[key] = value.item() ############################## # DISCRIMINATOR