Add a text-encoder adversarial loss to VITS

This commit is contained in:
Edresson Casanova 2022-06-02 18:40:57 -03:00
parent d9452d7038
commit 7f8c12888c
5 changed files with 98 additions and 17 deletions

View File

@ -118,6 +118,9 @@ class VitsConfig(BaseTTSConfig):
speaker_classifier_loss_alpha: float = 2.0
emotion_classifier_loss_alpha: float = 4.0
prosody_encoder_kl_loss_alpha: float = 5.0
disc_latent_loss_alpha: float = 1.5
gen_latent_loss_alpha: float = 1.5
feat_latent_loss_alpha: float = 1.5
# data loader params
return_wav: bool = True

View File

@ -593,6 +593,8 @@ class VitsGeneratorLoss(nn.Module):
self.emotion_classifier_alpha = c.emotion_classifier_loss_alpha
self.speaker_classifier_alpha = c.speaker_classifier_loss_alpha
self.prosody_encoder_kl_loss_alpha = c.prosody_encoder_kl_loss_alpha
self.feat_latent_loss_alpha = c.feat_latent_loss_alpha
self.gen_latent_loss_alpha = c.gen_latent_loss_alpha
self.stft = TorchSTFT(
c.audio.fft_size,
@ -671,6 +673,9 @@ class VitsGeneratorLoss(nn.Module):
loss_prosody_enc_emo_classifier=None,
loss_text_enc_spk_rev_classifier=None,
loss_text_enc_emo_classifier=None,
scores_disc_mp=None,
feats_disc_mp=None,
feats_disc_zp=None,
end2end_info=None,
):
"""
@ -728,6 +733,18 @@ class VitsGeneratorLoss(nn.Module):
loss += loss_text_enc_emo_classifier
return_dict["loss_text_enc_emo_classifier"] = loss_text_enc_emo_classifier
if scores_disc_mp is not None and feats_disc_mp is not None and feats_disc_zp is not None:
# feature loss
loss_feat_latent = (
self.feature_loss(feats_real=feats_disc_zp, feats_generated=feats_disc_mp) * self.feat_latent_loss_alpha
)
return_dict["loss_feat_latent"] = loss_feat_latent
loss += return_dict["loss_feat_latent"]
# gen loss
loss_gen_latent = self.generator_loss(scores_fake=scores_disc_mp)[0] * self.gen_latent_loss_alpha
return_dict["loss_gen_latent"] = loss_gen_latent
loss += return_dict["loss_gen_latent"]
if vae_outputs is not None:
posterior_distribution, prior_distribution = vae_outputs
# KL divergence term between the posterior and the prior
@ -784,6 +801,8 @@ class VitsDiscriminatorLoss(nn.Module):
def __init__(self, c: Coqpit):
    """Store the discriminator loss weights taken from the training config.

    Args:
        c (Coqpit): model configuration providing ``disc_loss_alpha`` and
            ``disc_latent_loss_alpha``.
    """
    super().__init__()
    # Weights for the waveform discriminator loss and the latent discriminator loss.
    for weight_name in ("disc_loss_alpha", "disc_latent_loss_alpha"):
        setattr(self, weight_name, getattr(c, weight_name))
@staticmethod
def discriminator_loss(scores_real, scores_fake):
@ -800,19 +819,26 @@ class VitsDiscriminatorLoss(nn.Module):
fake_losses.append(fake_loss.item())
return loss, real_losses, fake_losses
def forward(self, scores_disc_real, scores_disc_fake, end2end_info=None):
loss = 0.0
def forward(self, scores_disc_real, scores_disc_fake, scores_disc_zp=None, scores_disc_mp=None, end2end_info=None):
return_dict = {}
return_dict["loss"] = 0.0
loss_disc, loss_disc_real, _ = self.discriminator_loss(
scores_real=scores_disc_real, scores_fake=scores_disc_fake
)
return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
loss = loss + return_dict["loss_disc"]
return_dict["loss"] = loss
return_dict["loss"] += return_dict["loss_disc"]
for i, ldr in enumerate(loss_disc_real):
return_dict[f"loss_disc_real_{i}"] = ldr
# latent discriminator
if scores_disc_zp is not None and scores_disc_mp is not None:
loss_disc_latent, _, _ = self.discriminator_loss(
scores_real=scores_disc_zp, scores_fake=scores_disc_mp
)
return_dict["loss_disc_latent"] = loss_disc_latent * self.disc_latent_loss_alpha
return_dict["loss"] += return_dict["loss_disc_latent"]
if end2end_info is not None:
loss_disc_end2end, loss_disc_real_end2end, _ = self.discriminator_loss(
scores_real=end2end_info["scores_disc_real"], scores_fake=end2end_info["scores_disc_fake"],

View File

@ -58,13 +58,16 @@ class VitsDiscriminator(nn.Module):
use_spectral_norm (bool): if `True` switch to spectral norm instead of weight norm.
"""
def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False, use_latent_disc=False):
    """Build the sub-discriminators.

    Args:
        periods (tuple): periods for the multi-period discriminators.
        use_spectral_norm (bool): if `True` switch to spectral norm instead of weight norm.
        use_latent_disc (bool): if `True` also build the latent-space discriminator.
    """
    super().__init__()
    # One scale discriminator followed by one period discriminator per period.
    discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
    discs += [DiscriminatorP(period, use_spectral_norm=use_spectral_norm) for period in periods]
    self.nets = nn.ModuleList(discs)
    # Optional adversarial discriminator over the latent space (scores m_p vs z_p).
    self.disc_latent = LatentDiscriminator(use_spectral_norm=use_spectral_norm) if use_latent_disc else None
def forward(self, x, x_hat=None):
def forward(self, x, x_hat=None, m_p=None, z_p=None):
"""
Args:
x (Tensor): ground truth waveform.
@ -86,4 +89,44 @@ class VitsDiscriminator(nn.Module):
x_hat_score, x_hat_feat = net(x_hat)
x_hat_scores.append(x_hat_score)
x_hat_feats.append(x_hat_feat)
return x_scores, x_feats, x_hat_scores, x_hat_feats
# variables latent disc
mp_scores, zp_scores, mp_feats, zp_feats = None, None, None, None
if self.disc_latent is not None:
if m_p is not None:
mp_scores, mp_feats = self.disc_latent(m_p.unsqueeze(1))
if z_p is not None:
zp_scores, zp_feats = self.disc_latent(z_p.unsqueeze(1))
return x_scores, x_feats, x_hat_scores, x_hat_feats, mp_scores, mp_feats, zp_scores, zp_feats
class LatentDiscriminator(nn.Module):
    """Discriminator with the same architecture as the Univnet SpecDiscriminator.

    It scores 2D latent maps (a latent tensor with a channel axis added, e.g.
    ``m_p.unsqueeze(1)`` or ``z_p.unsqueeze(1)`` as done by the caller) and
    returns both the flattened logits and the intermediate feature maps so a
    feature-matching loss can be computed.

    Args:
        use_spectral_norm (bool): if `True` switch to spectral norm instead of weight norm.
    """

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
        # Three stride-(1, 2) stages downsample only the last (time) axis.
        self.discriminators = nn.ModuleList(
            [
                norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
                norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
                norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
                norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
                norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
            ]
        )
        self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))

    def forward(self, y):
        """Run the discriminator.

        Args:
            y (Tensor): input of shape `[B, 1, H, W]`.

        Returns:
            Tuple[Tensor, List[Tensor]]: flattened scores `[B, H * W']` and the
                feature maps of every layer (the final score map included).
        """
        fmap = []
        # The original loop used ``enumerate`` but discarded the index; iterate directly.
        for layer in self.discriminators:
            y = layer(y)
            y = torch.nn.functional.leaky_relu(y, 0.1)
            fmap.append(y)
        y = self.out(y)
        fmap.append(y)
        return torch.flatten(y, 1, -1), fmap

View File

@ -20,7 +20,7 @@ from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
from TTS.tts.layers.generic.classifier import ReversalClassifier
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.layers.vits.discriminator import VitsDiscriminator, LatentDiscriminator
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
@ -563,6 +563,8 @@ class VitsArgs(Coqpit):
use_end2end_loss: bool = False
use_soft_dtw: bool = False
use_latent_discriminator: bool = False
detach_dp_input: bool = True
use_language_embedding: bool = False
embedded_language_dim: int = 4
@ -789,6 +791,7 @@ class Vits(BaseTTS):
self.disc = VitsDiscriminator(
periods=self.args.periods_multi_period_discriminator,
use_spectral_norm=self.args.use_spectral_norm_disriminator,
use_latent_disc=self.args.use_latent_discriminator,
)
def init_multispeaker(self, config: Coqpit):
@ -1631,13 +1634,13 @@ class Vits(BaseTTS):
self.model_outputs_cache = outputs # pylint: disable=attribute-defined-outside-init
# compute scores and features
scores_disc_fake, _, scores_disc_real, _ = self.disc(
outputs["model_outputs"].detach(), outputs["waveform_seg"]
scores_disc_fake, _, scores_disc_real, _, scores_disc_mp, _, scores_disc_zp, _= self.disc(
outputs["model_outputs"].detach(), outputs["waveform_seg"], outputs["m_p"].detach(), outputs["z_p"].detach()
)
end2end_info = None
if self.args.use_end2end_loss:
scores_disc_fake_end2end, _, scores_disc_real_end2end, _ = self.disc(
scores_disc_fake_end2end, _, scores_disc_real_end2end, _, _, _, _, _ = self.disc(
outputs["end2end_info"]["model_outputs"].detach(), self.model_outputs_cache["end2end_info"]["waveform_seg"]
)
end2end_info = {"scores_disc_real": scores_disc_real_end2end, "scores_disc_fake": scores_disc_fake_end2end}
@ -1647,6 +1650,8 @@ class Vits(BaseTTS):
loss_dict = criterion[optimizer_idx](
scores_disc_real,
scores_disc_fake,
scores_disc_zp,
scores_disc_mp,
end2end_info=end2end_info,
)
return outputs, loss_dict
@ -1678,12 +1683,12 @@ class Vits(BaseTTS):
)
# compute discriminator scores and features
scores_disc_fake, feats_disc_fake, _, feats_disc_real = self.disc(
self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"]
scores_disc_fake, feats_disc_fake, _, feats_disc_real, scores_disc_mp, feats_disc_mp, _, feats_disc_zp = self.disc(
self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"], self.model_outputs_cache["m_p"], self.model_outputs_cache["z_p"].detach()
)
if self.args.use_end2end_loss:
scores_disc_fake_end2end, feats_disc_fake_end2end, _, feats_disc_real_end2end = self.disc(
scores_disc_fake_end2end, feats_disc_fake_end2end, _, feats_disc_real_end2end, _, _, _, _, _ = self.disc(
self.model_outputs_cache["end2end_info"]["model_outputs"], self.model_outputs_cache["end2end_info"]["waveform_seg"]
)
self.model_outputs_cache["end2end_info"]["scores_disc_fake"] = scores_disc_fake_end2end
@ -1713,6 +1718,9 @@ class Vits(BaseTTS):
loss_prosody_enc_emo_classifier=self.model_outputs_cache["loss_prosody_enc_emo_classifier"],
loss_text_enc_spk_rev_classifier=self.model_outputs_cache["loss_text_enc_spk_rev_classifier"],
loss_text_enc_emo_classifier=self.model_outputs_cache["loss_text_enc_emo_classifier"],
scores_disc_mp=scores_disc_mp,
feats_disc_mp=feats_disc_mp,
feats_disc_zp=feats_disc_zp,
end2end_info=self.model_outputs_cache["end2end_info"],
)

View File

@ -12,8 +12,8 @@ output_path = os.path.join(get_tests_output_path(), "train_outputs")
config = VitsConfig(
batch_size=2,
eval_batch_size=2,
batch_size=3,
eval_batch_size=3,
num_loader_workers=0,
num_eval_loader_workers=0,
text_cleaner="english_cleaners",
@ -52,8 +52,9 @@ config.model_args.use_prosody_encoder_z_p_input = True
config.model_args.prosody_encoder_type = "vae"
config.model_args.detach_prosody_enc_input = True
config.model_args.use_latent_discriminator = True
# enable end2end loss
config.model_args.use_end2end_loss = True
config.model_args.use_end2end_loss = False
config.mixed_precision = False