Fix the VITS GAN loss

This commit is contained in:
Edresson Casanova 2022-06-03 13:05:40 +00:00
parent e07fcc7a8c
commit d6d8d0e3e1
4 changed files with 26 additions and 14 deletions

View File

@ -118,9 +118,9 @@ class VitsConfig(BaseTTSConfig):
speaker_classifier_loss_alpha: float = 2.0 speaker_classifier_loss_alpha: float = 2.0
emotion_classifier_loss_alpha: float = 4.0 emotion_classifier_loss_alpha: float = 4.0
prosody_encoder_kl_loss_alpha: float = 5.0 prosody_encoder_kl_loss_alpha: float = 5.0
disc_latent_loss_alpha: float = 1.5 disc_latent_loss_alpha: float = 5.0
gen_latent_loss_alpha: float = 1.5 gen_latent_loss_alpha: float = 5.0
feat_latent_loss_alpha: float = 1.5 feat_latent_loss_alpha: float = 108.0
# data loader params # data loader params
return_wav: bool = True return_wav: bool = True

View File

@ -9,7 +9,13 @@ from torch.nn import functional
from TTS.tts.utils.helpers import sequence_mask from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.ssim import ssim from TTS.tts.utils.ssim import ssim
from TTS.utils.audio import TorchSTFT from TTS.utils.audio import TorchSTFT
from TTS.vocoder.layers.losses import (
MelganFeatureLoss,
MSEDLoss,
MSEGLoss,
_apply_D_loss,
_apply_G_adv_loss,
)
# pylint: disable=abstract-method # pylint: disable=abstract-method
# relates https://github.com/pytorch/pytorch/issues/42305 # relates https://github.com/pytorch/pytorch/issues/42305
@ -607,6 +613,9 @@ class VitsGeneratorLoss(nn.Module):
use_mel=True, use_mel=True,
do_amp_to_db=True, do_amp_to_db=True,
) )
if c.model_args.use_latent_discriminator:
self.latent_feat_match_loss = MelganFeatureLoss()
self.gen_latent_gan_loss = MSEGLoss()
@staticmethod @staticmethod
def feature_loss(feats_real, feats_generated): def feature_loss(feats_real, feats_generated):
@ -735,13 +744,11 @@ class VitsGeneratorLoss(nn.Module):
if scores_disc_mp is not None and feats_disc_mp is not None and feats_disc_zp is not None: if scores_disc_mp is not None and feats_disc_mp is not None and feats_disc_zp is not None:
# feature loss # feature loss
loss_feat_latent = ( loss_feat_latent = self.latent_feat_match_loss(feats_disc_mp, feats_disc_zp) * self.feat_latent_loss_alpha
self.feature_loss(feats_real=feats_disc_zp, feats_generated=feats_disc_mp) * self.feat_latent_loss_alpha
)
return_dict["loss_feat_latent"] = loss_feat_latent return_dict["loss_feat_latent"] = loss_feat_latent
loss += return_dict["loss_feat_latent"] loss += return_dict["loss_feat_latent"]
# gen loss # gen loss
loss_gen_latent = self.generator_loss(scores_fake=scores_disc_mp)[0] * self.gen_latent_loss_alpha loss_gen_latent = _apply_G_adv_loss(scores_disc_mp, self.gen_latent_gan_loss) * self.gen_latent_loss_alpha
return_dict["loss_gen_latent"] = loss_gen_latent return_dict["loss_gen_latent"] = loss_gen_latent
loss += return_dict["loss_gen_latent"] loss += return_dict["loss_gen_latent"]
@ -801,7 +808,9 @@ class VitsDiscriminatorLoss(nn.Module):
def __init__(self, c: Coqpit): def __init__(self, c: Coqpit):
super().__init__() super().__init__()
self.disc_loss_alpha = c.disc_loss_alpha self.disc_loss_alpha = c.disc_loss_alpha
if c.model_args.use_latent_discriminator:
self.disc_latent_loss_alpha = c.disc_latent_loss_alpha self.disc_latent_loss_alpha = c.disc_latent_loss_alpha
self.disc_latent_gan_loss = MSEDLoss()
@staticmethod @staticmethod
@ -833,9 +842,11 @@ class VitsDiscriminatorLoss(nn.Module):
# latent discriminator # latent discriminator
if scores_disc_zp is not None and scores_disc_mp is not None: if scores_disc_zp is not None and scores_disc_mp is not None:
loss_disc_latent, _, _ = self.discriminator_loss( loss_disc_latent, loss_disc_latent_zp, loss_disc_latent_mp = _apply_D_loss(
scores_real=scores_disc_zp, scores_fake=scores_disc_mp scores_fake=scores_disc_mp, scores_real=scores_disc_zp, loss_func=self.disc_latent_gan_loss
) )
return_dict["loss_disc_latent_mp"] = loss_disc_latent_mp
return_dict["loss_disc_latent_zp"] = loss_disc_latent_zp
return_dict["loss_disc_latent"] = loss_disc_latent * self.disc_latent_loss_alpha return_dict["loss_disc_latent"] = loss_disc_latent * self.disc_latent_loss_alpha
return_dict["loss"] += return_dict["loss_disc_latent"] return_dict["loss"] += return_dict["loss_disc_latent"]

View File

@ -36,6 +36,7 @@ from TTS.tts.utils.visual import plot_alignment
from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results from TTS.vocoder.utils.generic_utils import plot_results
############################## ##############################
# IO / Feature extraction # IO / Feature extraction
############################## ##############################

View File

@ -12,8 +12,8 @@ output_path = os.path.join(get_tests_output_path(), "train_outputs")
config = VitsConfig( config = VitsConfig(
batch_size=3, batch_size=2,
eval_batch_size=3, eval_batch_size=2,
num_loader_workers=0, num_loader_workers=0,
num_eval_loader_workers=0, num_eval_loader_workers=0,
text_cleaner="english_cleaners", text_cleaner="english_cleaners",
@ -52,7 +52,7 @@ config.model_args.use_prosody_encoder_z_p_input = True
config.model_args.prosody_encoder_type = "vae" config.model_args.prosody_encoder_type = "vae"
config.model_args.detach_prosody_enc_input = True config.model_args.detach_prosody_enc_input = True
config.model_args.use_latent_discriminator = True config.model_args.use_latent_discriminator = False
# enable end2end loss # enable end2end loss
config.model_args.use_end2end_loss = False config.model_args.use_end2end_loss = False