mirror of https://github.com/coqui-ai/TTS.git
bug fix vocoder training
This commit is contained in:
parent
fd8f1ecb7d
commit
189227c741
|
@ -199,8 +199,8 @@ class GeneratorLoss(nn.Module):
|
||||||
|
|
||||||
self.stft_loss_weight = C.stft_loss_weight
|
self.stft_loss_weight = C.stft_loss_weight
|
||||||
self.subband_stft_loss_weight = C.subband_stft_loss_weight
|
self.subband_stft_loss_weight = C.subband_stft_loss_weight
|
||||||
self.mse_gan_loss_weight = C.mse_gan_loss_weight
|
self.mse_gan_loss_weight = C.mse_G_loss_weight
|
||||||
self.hinge_gan_loss_weight = C.hinge_gan_loss_weight
|
self.hinge_gan_loss_weight = C.hinge_G_loss_weight
|
||||||
self.feat_match_loss_weight = C.feat_match_loss_weight
|
self.feat_match_loss_weight = C.feat_match_loss_weight
|
||||||
|
|
||||||
if C.use_stft_loss:
|
if C.use_stft_loss:
|
||||||
|
|
|
@ -173,7 +173,10 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
||||||
|
|
||||||
loss_dict = dict()
|
loss_dict = dict()
|
||||||
for key, value in loss_G_dict.items():
|
for key, value in loss_G_dict.items():
|
||||||
loss_dict[key] = value.item()
|
if isinstance(value, int):
|
||||||
|
loss_dict[key] = value
|
||||||
|
else:
|
||||||
|
loss_dict[key] = value.item()
|
||||||
|
|
||||||
##############################
|
##############################
|
||||||
# DISCRIMINATOR
|
# DISCRIMINATOR
|
||||||
|
|
Loading…
Reference in New Issue