mirror of https://github.com/coqui-ai/TTS.git
Fix the VITS GAN loss
This commit is contained in:
parent
7f8c12888c
commit
cbc81b55cb
|
@@ -118,9 +118,9 @@ class VitsConfig(BaseTTSConfig):
|
|||
speaker_classifier_loss_alpha: float = 2.0
|
||||
emotion_classifier_loss_alpha: float = 4.0
|
||||
prosody_encoder_kl_loss_alpha: float = 5.0
|
||||
disc_latent_loss_alpha: float = 1.5
|
||||
gen_latent_loss_alpha: float = 1.5
|
||||
feat_latent_loss_alpha: float = 1.5
|
||||
disc_latent_loss_alpha: float = 5.0
|
||||
gen_latent_loss_alpha: float = 5.0
|
||||
feat_latent_loss_alpha: float = 108.0
|
||||
|
||||
# data loader params
|
||||
return_wav: bool = True
|
||||
|
|
|
@@ -9,7 +9,13 @@ from torch.nn import functional
|
|||
from TTS.tts.utils.helpers import sequence_mask
|
||||
from TTS.tts.utils.ssim import ssim
|
||||
from TTS.utils.audio import TorchSTFT
|
||||
|
||||
from TTS.vocoder.layers.losses import (
|
||||
MelganFeatureLoss,
|
||||
MSEDLoss,
|
||||
MSEGLoss,
|
||||
_apply_D_loss,
|
||||
_apply_G_adv_loss,
|
||||
)
|
||||
|
||||
# pylint: disable=abstract-method
|
||||
# relates https://github.com/pytorch/pytorch/issues/42305
|
||||
|
@@ -607,6 +613,9 @@ class VitsGeneratorLoss(nn.Module):
|
|||
use_mel=True,
|
||||
do_amp_to_db=True,
|
||||
)
|
||||
if c.model_args.use_latent_discriminator:
|
||||
self.latent_feat_match_loss = MelganFeatureLoss()
|
||||
self.gen_latent_gan_loss = MSEGLoss()
|
||||
|
||||
@staticmethod
|
||||
def feature_loss(feats_real, feats_generated):
|
||||
|
@@ -735,13 +744,11 @@ class VitsGeneratorLoss(nn.Module):
|
|||
|
||||
if scores_disc_mp is not None and feats_disc_mp is not None and feats_disc_zp is not None:
|
||||
# feature loss
|
||||
loss_feat_latent = (
|
||||
self.feature_loss(feats_real=feats_disc_zp, feats_generated=feats_disc_mp) * self.feat_latent_loss_alpha
|
||||
)
|
||||
loss_feat_latent = self.latent_feat_match_loss(feats_disc_mp, feats_disc_zp) * self.feat_latent_loss_alpha
|
||||
return_dict["loss_feat_latent"] = loss_feat_latent
|
||||
loss += return_dict["loss_feat_latent"]
|
||||
# gen loss
|
||||
loss_gen_latent = self.generator_loss(scores_fake=scores_disc_mp)[0] * self.gen_latent_loss_alpha
|
||||
loss_gen_latent = _apply_G_adv_loss(scores_disc_mp, self.gen_latent_gan_loss) * self.gen_latent_loss_alpha
|
||||
return_dict["loss_gen_latent"] = loss_gen_latent
|
||||
loss += return_dict["loss_gen_latent"]
|
||||
|
||||
|
@@ -801,7 +808,9 @@ class VitsDiscriminatorLoss(nn.Module):
|
|||
def __init__(self, c: Coqpit):
|
||||
super().__init__()
|
||||
self.disc_loss_alpha = c.disc_loss_alpha
|
||||
self.disc_latent_loss_alpha = c.disc_latent_loss_alpha
|
||||
if c.model_args.use_latent_discriminator:
|
||||
self.disc_latent_loss_alpha = c.disc_latent_loss_alpha
|
||||
self.disc_latent_gan_loss = MSEDLoss()
|
||||
|
||||
|
||||
@staticmethod
|
||||
|
@@ -833,9 +842,11 @@ class VitsDiscriminatorLoss(nn.Module):
|
|||
|
||||
# latent discriminator
|
||||
if scores_disc_zp is not None and scores_disc_mp is not None:
|
||||
loss_disc_latent, _, _ = self.discriminator_loss(
|
||||
scores_real=scores_disc_zp, scores_fake=scores_disc_mp
|
||||
loss_disc_latent, loss_disc_latent_zp, loss_disc_latent_mp = _apply_D_loss(
|
||||
scores_fake=scores_disc_mp, scores_real=scores_disc_zp, loss_func=self.disc_latent_gan_loss
|
||||
)
|
||||
return_dict["loss_disc_latent_mp"] = loss_disc_latent_mp
|
||||
return_dict["loss_disc_latent_zp"] = loss_disc_latent_zp
|
||||
return_dict["loss_disc_latent"] = loss_disc_latent * self.disc_latent_loss_alpha
|
||||
return_dict["loss"] += return_dict["loss_disc_latent"]
|
||||
|
||||
|
|
|
@@ -36,6 +36,7 @@ from TTS.tts.utils.visual import plot_alignment
|
|||
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
|
||||
from TTS.vocoder.utils.generic_utils import plot_results
|
||||
|
||||
|
||||
##############################
|
||||
# IO / Feature extraction
|
||||
##############################
|
||||
|
|
|
@@ -12,8 +12,8 @@ output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
|||
|
||||
|
||||
config = VitsConfig(
|
||||
batch_size=3,
|
||||
eval_batch_size=3,
|
||||
batch_size=2,
|
||||
eval_batch_size=2,
|
||||
num_loader_workers=0,
|
||||
num_eval_loader_workers=0,
|
||||
text_cleaner="english_cleaners",
|
||||
|
@@ -52,7 +52,7 @@ config.model_args.use_prosody_encoder_z_p_input = True
|
|||
config.model_args.prosody_encoder_type = "vae"
|
||||
config.model_args.detach_prosody_enc_input = True
|
||||
|
||||
config.model_args.use_latent_discriminator = True
|
||||
config.model_args.use_latent_discriminator = False
|
||||
# enable end2end loss
|
||||
config.model_args.use_end2end_loss = False
|
||||
|
||||
|
|
Loading…
Reference in New Issue