mirror of https://github.com/coqui-ai/TTS.git
correct loss normalization and function refactoring
This commit is contained in:
parent 34eacb6383
commit 6b1de26869
@@ -6,6 +6,7 @@ from torch.nn import functional as F
 
 class TorchSTFT():
     def __init__(self, n_fft, hop_length, win_length, window='hann_window'):
+        """ Torch based STFT operation """
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
@@ -33,6 +34,7 @@ class TorchSTFT():
 
 
 class STFTLoss(nn.Module):
+    """ Single scale STFT Loss """
     def __init__(self, n_fft, hop_length, win_length):
         super(STFTLoss, self).__init__()
         self.n_fft = n_fft
@@ -50,6 +52,7 @@ class STFTLoss(nn.Module):
         return loss_mag, loss_sc
 
 class MultiScaleSTFTLoss(torch.nn.Module):
+    """ Multi scale STFT loss """
     def __init__(self,
                  n_ffts=(1024, 2048, 512),
                  hop_lengths=(120, 240, 50),
@@ -73,6 +76,7 @@ class MultiScaleSTFTLoss(torch.nn.Module):
 
 
 class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
+    """ Multiscale STFT loss for multi band model outputs """
     def forward(self, y_hat, y):
         y_hat = y_hat.view(-1, 1, y_hat.shape[2])
         y = y.view(-1, 1, y.shape[2])
@@ -121,16 +125,62 @@ class MelganFeatureLoss(nn.Module):
         loss_feats = 0
         for fake_feat, real_feat in zip(fake_feats, real_feats):
             loss_feats += torch.mean(torch.abs(fake_feat - real_feat))
+        loss_feats /= len(fake_feats) + len(real_feats)
         return loss_feats
 
 
-##################################
+#####################################
 # LOSS WRAPPERS
+#####################################
+
+
+def _apply_G_adv_loss(scores_fake, loss_func):
+    """ Compute G adversarial loss function
+    and normalize values """
+    adv_loss = 0
+    if isinstance(scores_fake, list):
+        for score_fake in scores_fake:
+            fake_loss = loss_func(score_fake)
+            adv_loss += fake_loss
+        adv_loss /= len(scores_fake)
+    else:
+        fake_loss = loss_func(scores_fake)
+        adv_loss = fake_loss
+    return adv_loss
+
+
+def _apply_D_loss(scores_fake, scores_real, loss_func):
+    """ Compute D loss func and normalize loss values """
+    loss = 0
+    real_loss = 0
+    fake_loss = 0
+    if isinstance(scores_fake, list):
+        # multi-scale loss
+        for score_fake, score_real in zip(scores_fake, scores_real):
+            total_loss, real_loss, fake_loss = loss_func(score_fake, score_real)
+            loss += total_loss
+            real_loss += real_loss
+            fake_loss += fake_loss
+        # normalize loss values with number of scales
+        loss /= len(scores_fake)
+        real_loss /= len(scores_real)
+        fake_loss /= len(scores_fake)
+    else:
+        # single scale loss
+        total_loss, real_loss, fake_loss = loss_func(scores_fake, scores_real)
+        loss = total_loss
+    return loss, real_loss, fake_loss
+
+
+##################################
+# MODEL LOSSES
 ##################################
 
 
 class GeneratorLoss(nn.Module):
     def __init__(self, C):
+        """ Compute Generator Loss values depending on training
+        configuration """
         super(GeneratorLoss, self).__init__()
         assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\
             " [!] Cannot use HingeGANLoss and MSEGANLoss together."
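The two helpers above carry the loss-normalization fix: with a multi-scale discriminator they average the per-scale losses instead of summing them, and the new `loss_feats /= len(fake_feats) + len(real_feats)` line applies the same idea to the feature-matching loss. Below is a minimal sketch (not part of the commit) of how `_apply_G_adv_loss` is expected to behave, assuming the definition added above is in scope; the stand-in `mse_g_loss` callable and the tensor shapes are illustrative only.

import torch

def mse_g_loss(score_fake):
    # stand-in generator adversarial criterion (assumption): push fake scores towards 1
    return torch.nn.functional.mse_loss(score_fake, torch.ones_like(score_fake))

# one score map per discriminator scale (shapes are illustrative)
scores_fake = [torch.rand(4, 1, 100) for _ in range(3)]
adv_loss = _apply_G_adv_loss(scores_fake, mse_g_loss)  # averaged over the 3 scales, not summed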
@@ -159,7 +209,8 @@ class GeneratorLoss(nn.Module):
             self.feat_match_loss = MelganFeatureLoss()
 
     def forward(self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None):
-        loss = 0
+        gen_loss = 0
+        adv_loss = 0
         return_dict = {}
 
         # STFT Loss
@@ -167,50 +218,41 @@ class GeneratorLoss(nn.Module):
             stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat.squeeze(1), y.squeeze(1))
             return_dict['G_stft_loss_mg'] = stft_loss_mg
             return_dict['G_stft_loss_sc'] = stft_loss_sc
-            loss += self.stft_loss_weight * (stft_loss_mg + stft_loss_sc)
+            gen_loss += self.stft_loss_weight * (stft_loss_mg + stft_loss_sc)
 
         # subband STFT Loss
         if self.use_subband_stft_loss:
             subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub)
             return_dict['G_subband_stft_loss_mg'] = subband_stft_loss_mg
             return_dict['G_subband_stft_loss_sc'] = subband_stft_loss_sc
-            loss += self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc)
+            gen_loss += self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc)
 
-        # Fake Losses
+        # multiscale MSE adversarial loss
         if self.use_mse_gan_loss and scores_fake is not None:
-            mse_fake_loss = 0
-            if isinstance(scores_fake, list):
-                for score_fake in scores_fake:
-                    fake_loss = self.mse_loss(score_fake)
-                    mse_fake_loss += fake_loss
-            else:
-                fake_loss = self.mse_loss(scores_fake)
-                mse_fake_loss = fake_loss
+            mse_fake_loss = _apply_G_adv_loss(scores_fake, self.mse_loss)
             return_dict['G_mse_fake_loss'] = mse_fake_loss
-            loss += self.mse_gan_loss_weight * mse_fake_loss
+            adv_loss += self.mse_gan_loss_weight * mse_fake_loss
 
+        # multiscale Hinge adversarial loss
         if self.use_hinge_gan_loss and not scores_fake is not None:
-            hinge_fake_loss = 0
-            if isinstance(scores_fake, list):
-                for score_fake in scores_fake:
-                    fake_loss = self.hinge_loss(score_fake)
-                    hinge_fake_loss += fake_loss
-            else:
-                fake_loss = self.hinge_loss(scores_fake)
-                hinge_fake_loss = fake_loss
+            hinge_fake_loss = _apply_G_adv_loss(scores_fake, self.hinge_loss)
             return_dict['G_hinge_fake_loss'] = hinge_fake_loss
-            loss += self.hinge_gan_loss_weight * hinge_fake_loss
+            adv_loss += self.hinge_gan_loss_weight * hinge_fake_loss
 
         # Feature Matching Loss
         if self.use_feat_match_loss and not feats_fake:
             feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
             return_dict['G_feat_match_loss'] = feat_match_loss
-            loss += self.feat_match_loss_weight * feat_match_loss
-        return_dict['G_loss'] = loss
+            adv_loss += self.feat_match_loss_weight * feat_match_loss
+        return_dict['G_loss'] = gen_loss + adv_loss
+        return_dict['G_gen_loss'] = gen_loss
+        return_dict['G_adv_loss'] = adv_loss
         return return_dict
 
 
 class DiscriminatorLoss(nn.Module):
+    """ Compute Discriminator Loss values depending on training
+    configuration """
     def __init__(self, C):
         super(DiscriminatorLoss, self).__init__()
         assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\
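After this refactoring, GeneratorLoss.forward reports the STFT-based terms under `G_gen_loss`, the adversarial and feature-matching terms under `G_adv_loss`, and their sum under `G_loss`. A hedged sketch of how a training step would consume the returned dictionary; `criterion_gen` and the input tensors are placeholders, not code from this commit.

# assumes criterion_gen = GeneratorLoss(C) and the usual model/discriminator outputs
ret = criterion_gen(y_hat=y_hat, y=y, scores_fake=scores_fake,
                    feats_fake=feats_fake, feats_real=feats_real)
ret['G_loss'].backward()                      # G_loss = G_gen_loss + G_adv_loss
logs = {k: float(v) for k, v in ret.items()}  # per-term scalars for logging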
@@ -219,9 +261,6 @@ class DiscriminatorLoss(nn.Module):
         self.use_mse_gan_loss = C.use_mse_gan_loss
         self.use_hinge_gan_loss = C.use_hinge_gan_loss
 
-        self.mse_gan_loss_weight = C.mse_gan_loss_weight
-        self.hinge_gan_loss_weight = C.hinge_gan_loss_weight
-
         if C.use_mse_gan_loss:
             self.mse_loss = MSEDLoss()
         if C.use_hinge_gan_loss:
@@ -232,44 +271,18 @@ class DiscriminatorLoss(nn.Module):
         return_dict = {}
 
         if self.use_mse_gan_loss:
-            mse_gan_loss = 0
-            mse_gan_real_loss = 0
-            mse_gan_fake_loss = 0
-            if isinstance(scores_fake, list):
-                for score_fake, score_real in zip(scores_fake, scores_real):
-                    total_loss, real_loss, fake_loss = self.mse_loss(score_fake, score_real)
-                    mse_gan_loss += total_loss
-                    mse_gan_real_loss += real_loss
-                    mse_gan_fake_loss += fake_loss
-            else:
-                total_loss, real_loss, fake_loss = self.mse_loss(scores_fake, scores_real)
-                mse_gan_loss = total_loss
-                mse_gan_real_loss = real_loss
-                mse_gan_fake_loss = fake_loss
-            return_dict['D_mse_gan_loss'] = mse_gan_loss
-            return_dict['D_mse_gan_real_loss'] = mse_gan_real_loss
-            return_dict['D_mse_gan_fake_loss'] = mse_gan_fake_loss
-            loss += self.mse_gan_loss_weight * mse_gan_loss
+            mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss(scores_fake, scores_real, self.mse_loss)
+            return_dict['D_mse_gan_loss'] = mse_D_loss
+            return_dict['D_mse_gan_real_loss'] = mse_D_real_loss
+            return_dict['D_mse_gan_fake_loss'] = mse_D_fake_loss
+            loss += mse_D_loss
 
         if self.use_hinge_gan_loss:
-            hinge_gan_loss = 0
-            hinge_gan_real_loss = 0
-            hinge_gan_fake_loss = 0
-            if isinstance(scores_fake, list):
-                for score_fake, score_real in zip(scores_fake, scores_real):
-                    total_loss, real_loss, fake_loss = self.hinge_loss(score_fake, score_real)
-                    hinge_gan_loss += total_loss
-                    hinge_gan_real_loss += real_loss
-                    hinge_gan_fake_loss += fake_loss
-            else:
-                total_loss, real_loss, fake_loss = self.hinge_loss(scores_fake, scores_real)
-                hinge_gan_loss = total_loss
-                hinge_gan_real_loss = real_loss
-                hinge_gan_fake_loss = fake_loss
-            return_dict['D_hinge_gan_loss'] = hinge_gan_loss
-            return_dict['D_hinge_gan_real_loss'] = hinge_gan_real_loss
-            return_dict['D_hinge_gan_fake_loss'] = hinge_gan_fake_loss
-            loss += self.hinge_gan_loss_weight * hinge_gan_loss
+            hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss(scores_fake, scores_real, self.hinge_loss)
+            return_dict['D_hinge_gan_loss'] = hinge_D_loss
+            return_dict['D_hinge_gan_real_loss'] = hinge_D_real_loss
+            return_dict['D_hinge_gan_fake_loss'] = hinge_D_fake_loss
+            loss += hinge_D_loss
 
         return_dict['D_loss'] = loss
         return return_dict
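On the discriminator side, `_apply_D_loss` returns the averaged total, real, and fake losses per call, and the refactored forward adds the unweighted `mse_D_loss`/`hinge_D_loss` directly, which is why the per-term weight attributes were dropped in the previous hunk. A small sketch of the expected call, assuming the module-level `MSEDLoss` criterion returns a `(total, real, fake)` tuple as the code above implies; the shapes are illustrative, not from the commit.

import torch

scores_fake = [torch.rand(4, 1, 50) for _ in range(3)]  # per-scale fake scores
scores_real = [torch.rand(4, 1, 50) for _ in range(3)]  # per-scale real scores
d_loss, d_real, d_fake = _apply_D_loss(scores_fake, scores_real, MSEDLoss())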