bug fix vocoder training

This commit is contained in:
erogol 2020-06-05 13:23:11 +02:00
parent 16eb15a5ff
commit acb367be26
2 changed files with 6 additions and 3 deletions

View File

@ -199,8 +199,8 @@ class GeneratorLoss(nn.Module):
self.stft_loss_weight = C.stft_loss_weight
self.subband_stft_loss_weight = C.subband_stft_loss_weight
self.mse_gan_loss_weight = C.mse_gan_loss_weight
self.hinge_gan_loss_weight = C.hinge_gan_loss_weight
self.mse_gan_loss_weight = C.mse_G_loss_weight
self.hinge_gan_loss_weight = C.hinge_G_loss_weight
self.feat_match_loss_weight = C.feat_match_loss_weight
if C.use_stft_loss:

View File

@ -173,7 +173,10 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
loss_dict = dict()
for key, value in loss_G_dict.items():
loss_dict[key] = value.item()
if isinstance(value, int):
loss_dict[key] = value
else:
loss_dict[key] = value.item()
##############################
# DISCRIMINATOR