mirror of https://github.com/coqui-ai/TTS.git
Add text encoder adversarial loss to the VITS model
This commit is contained in:
parent
4e94b46d5e
commit
e07fcc7a8c
|
@ -118,6 +118,9 @@ class VitsConfig(BaseTTSConfig):
|
|||
speaker_classifier_loss_alpha: float = 2.0
|
||||
emotion_classifier_loss_alpha: float = 4.0
|
||||
prosody_encoder_kl_loss_alpha: float = 5.0
|
||||
disc_latent_loss_alpha: float = 1.5
|
||||
gen_latent_loss_alpha: float = 1.5
|
||||
feat_latent_loss_alpha: float = 1.5
|
||||
|
||||
# data loader params
|
||||
return_wav: bool = True
|
||||
|
|
|
@ -593,6 +593,8 @@ class VitsGeneratorLoss(nn.Module):
|
|||
self.emotion_classifier_alpha = c.emotion_classifier_loss_alpha
|
||||
self.speaker_classifier_alpha = c.speaker_classifier_loss_alpha
|
||||
self.prosody_encoder_kl_loss_alpha = c.prosody_encoder_kl_loss_alpha
|
||||
self.feat_latent_loss_alpha = c.feat_latent_loss_alpha
|
||||
self.gen_latent_loss_alpha = c.gen_latent_loss_alpha
|
||||
|
||||
self.stft = TorchSTFT(
|
||||
c.audio.fft_size,
|
||||
|
@ -671,6 +673,9 @@ class VitsGeneratorLoss(nn.Module):
|
|||
loss_prosody_enc_emo_classifier=None,
|
||||
loss_text_enc_spk_rev_classifier=None,
|
||||
loss_text_enc_emo_classifier=None,
|
||||
scores_disc_mp=None,
|
||||
feats_disc_mp=None,
|
||||
feats_disc_zp=None,
|
||||
end2end_info=None,
|
||||
):
|
||||
"""
|
||||
|
@ -728,6 +733,18 @@ class VitsGeneratorLoss(nn.Module):
|
|||
loss += loss_text_enc_emo_classifier
|
||||
return_dict["loss_text_enc_emo_classifier"] = loss_text_enc_emo_classifier
|
||||
|
||||
if scores_disc_mp is not None and feats_disc_mp is not None and feats_disc_zp is not None:
|
||||
# feature loss
|
||||
loss_feat_latent = (
|
||||
self.feature_loss(feats_real=feats_disc_zp, feats_generated=feats_disc_mp) * self.feat_latent_loss_alpha
|
||||
)
|
||||
return_dict["loss_feat_latent"] = loss_feat_latent
|
||||
loss += return_dict["loss_feat_latent"]
|
||||
# gen loss
|
||||
loss_gen_latent = self.generator_loss(scores_fake=scores_disc_mp)[0] * self.gen_latent_loss_alpha
|
||||
return_dict["loss_gen_latent"] = loss_gen_latent
|
||||
loss += return_dict["loss_gen_latent"]
|
||||
|
||||
if vae_outputs is not None:
|
||||
posterior_distribution, prior_distribution = vae_outputs
|
||||
# KL divergence term between the posterior and the prior
|
||||
|
@ -784,6 +801,8 @@ class VitsDiscriminatorLoss(nn.Module):
|
|||
def __init__(self, c: Coqpit):
    """Store the discriminator loss weights taken from the model config.

    Args:
        c (Coqpit): model configuration exposing ``disc_loss_alpha`` and
            ``disc_latent_loss_alpha``.
    """
    super().__init__()
    # Weight of the waveform (multi-period/scale) discriminator loss term.
    self.disc_loss_alpha = c.disc_loss_alpha
    # Weight of the latent (m_p vs. z_p) discriminator loss term.
    self.disc_latent_loss_alpha = c.disc_latent_loss_alpha
|
||||
|
||||
|
||||
@staticmethod
|
||||
def discriminator_loss(scores_real, scores_fake):
|
||||
|
@ -800,19 +819,26 @@ class VitsDiscriminatorLoss(nn.Module):
|
|||
fake_losses.append(fake_loss.item())
|
||||
return loss, real_losses, fake_losses
|
||||
|
||||
def forward(self, scores_disc_real, scores_disc_fake, end2end_info=None):
|
||||
loss = 0.0
|
||||
def forward(self, scores_disc_real, scores_disc_fake, scores_disc_zp=None, scores_disc_mp=None, end2end_info=None):
|
||||
return_dict = {}
|
||||
return_dict["loss"] = 0.0
|
||||
loss_disc, loss_disc_real, _ = self.discriminator_loss(
|
||||
scores_real=scores_disc_real, scores_fake=scores_disc_fake
|
||||
)
|
||||
return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
|
||||
loss = loss + return_dict["loss_disc"]
|
||||
return_dict["loss"] = loss
|
||||
return_dict["loss"] += return_dict["loss_disc"]
|
||||
|
||||
for i, ldr in enumerate(loss_disc_real):
|
||||
return_dict[f"loss_disc_real_{i}"] = ldr
|
||||
|
||||
# latent discriminator
|
||||
if scores_disc_zp is not None and scores_disc_mp is not None:
|
||||
loss_disc_latent, _, _ = self.discriminator_loss(
|
||||
scores_real=scores_disc_zp, scores_fake=scores_disc_mp
|
||||
)
|
||||
return_dict["loss_disc_latent"] = loss_disc_latent * self.disc_latent_loss_alpha
|
||||
return_dict["loss"] += return_dict["loss_disc_latent"]
|
||||
|
||||
if end2end_info is not None:
|
||||
loss_disc_end2end, loss_disc_real_end2end, _ = self.discriminator_loss(
|
||||
scores_real=end2end_info["scores_disc_real"], scores_fake=end2end_info["scores_disc_fake"],
|
||||
|
|
|
@ -58,13 +58,16 @@ class VitsDiscriminator(nn.Module):
|
|||
use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm.
|
||||
"""
|
||||
|
||||
def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False):
|
||||
def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False, use_latent_disc=False):
    """Build the waveform sub-discriminators and, optionally, a latent discriminator.

    Args:
        periods (tuple): periods of the multi-period discriminators.
        use_spectral_norm (bool): if `True`, switch to spectral norm instead of weight norm.
        use_latent_disc (bool): if `True`, also instantiate a discriminator over latent features.
    """
    super().__init__()
    # One scale discriminator followed by one period discriminator per period.
    self.nets = nn.ModuleList([DiscriminatorS(use_spectral_norm=use_spectral_norm)])
    self.nets.extend(DiscriminatorP(p, use_spectral_norm=use_spectral_norm) for p in periods)
    # Optional adversary on latent representations (prior mean m_p vs. flow output z_p).
    self.disc_latent = LatentDiscriminator(use_spectral_norm=use_spectral_norm) if use_latent_disc else None
|
||||
|
||||
def forward(self, x, x_hat=None):
|
||||
def forward(self, x, x_hat=None, m_p=None, z_p=None):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): ground truth waveform.
|
||||
|
@ -86,4 +89,44 @@ class VitsDiscriminator(nn.Module):
|
|||
x_hat_score, x_hat_feat = net(x_hat)
|
||||
x_hat_scores.append(x_hat_score)
|
||||
x_hat_feats.append(x_hat_feat)
|
||||
return x_scores, x_feats, x_hat_scores, x_hat_feats
|
||||
|
||||
# variables latent disc
|
||||
mp_scores, zp_scores, mp_feats, zp_feats = None, None, None, None
|
||||
if self.disc_latent is not None:
|
||||
if m_p is not None:
|
||||
mp_scores, mp_feats = self.disc_latent(m_p.unsqueeze(1))
|
||||
if z_p is not None:
|
||||
zp_scores, zp_feats = self.disc_latent(z_p.unsqueeze(1))
|
||||
|
||||
return x_scores, x_feats, x_hat_scores, x_hat_feats, mp_scores, mp_feats, zp_scores, zp_feats
|
||||
|
||||
|
||||
class LatentDiscriminator(nn.Module):
    """Discriminator with the same architecture as the Univnet SpecDiscriminator.

    Treats a latent tensor as a single-channel 2D map and scores it with a
    stack of 2D convolutions that downsample along the last (time) axis only.
    """

    def __init__(self, use_spectral_norm=False):
        """
        Args:
            use_spectral_norm (bool): if `True`, use spectral norm instead of weight norm.
        """
        super().__init__()
        norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
        self.discriminators = nn.ModuleList(
            [
                norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
                norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
                norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
                norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
                norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
            ]
        )
        self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))

    def forward(self, y):
        """Score a latent feature map.

        Args:
            y (Tensor): input of shape `[B, 1, H, W]`.

        Returns:
            Tuple[Tensor, List[Tensor]]: per-sample scores flattened to
                `[B, H * W']` and the list of intermediate feature maps
                (used for the feature-matching loss).
        """
        fmap = []
        # Fix: the original used `for _, d in enumerate(...)` with an unused
        # index; iterate the modules directly.
        for layer in self.discriminators:
            y = layer(y)
            y = torch.nn.functional.leaky_relu(y, 0.1)
            fmap.append(y)

        y = self.out(y)
        fmap.append(y)

        return torch.flatten(y, 1, -1), fmap
|
||||
|
|
|
@ -20,7 +20,7 @@ from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
|
|||
from TTS.tts.layers.generic.classifier import ReversalClassifier
|
||||
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
||||
from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE
|
||||
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
|
||||
from TTS.tts.layers.vits.discriminator import VitsDiscriminator, LatentDiscriminator
|
||||
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
|
||||
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
|
||||
|
||||
|
@ -563,6 +563,8 @@ class VitsArgs(Coqpit):
|
|||
use_end2end_loss: bool = False
|
||||
use_soft_dtw: bool = False
|
||||
|
||||
use_latent_discriminator: bool = False
|
||||
|
||||
detach_dp_input: bool = True
|
||||
use_language_embedding: bool = False
|
||||
embedded_language_dim: int = 4
|
||||
|
@ -789,6 +791,7 @@ class Vits(BaseTTS):
|
|||
self.disc = VitsDiscriminator(
|
||||
periods=self.args.periods_multi_period_discriminator,
|
||||
use_spectral_norm=self.args.use_spectral_norm_disriminator,
|
||||
use_latent_disc=self.args.use_latent_discriminator,
|
||||
)
|
||||
|
||||
def init_multispeaker(self, config: Coqpit):
|
||||
|
@ -1631,13 +1634,13 @@ class Vits(BaseTTS):
|
|||
self.model_outputs_cache = outputs # pylint: disable=attribute-defined-outside-init
|
||||
|
||||
# compute scores and features
|
||||
scores_disc_fake, _, scores_disc_real, _ = self.disc(
|
||||
outputs["model_outputs"].detach(), outputs["waveform_seg"]
|
||||
scores_disc_fake, _, scores_disc_real, _, scores_disc_mp, _, scores_disc_zp, _= self.disc(
|
||||
outputs["model_outputs"].detach(), outputs["waveform_seg"], outputs["m_p"].detach(), outputs["z_p"].detach()
|
||||
)
|
||||
|
||||
end2end_info = None
|
||||
if self.args.use_end2end_loss:
|
||||
scores_disc_fake_end2end, _, scores_disc_real_end2end, _ = self.disc(
|
||||
scores_disc_fake_end2end, _, scores_disc_real_end2end, _, _, _, _, _ = self.disc(
|
||||
outputs["end2end_info"]["model_outputs"].detach(), self.model_outputs_cache["end2end_info"]["waveform_seg"]
|
||||
)
|
||||
end2end_info = {"scores_disc_real": scores_disc_real_end2end, "scores_disc_fake": scores_disc_fake_end2end}
|
||||
|
@ -1647,6 +1650,8 @@ class Vits(BaseTTS):
|
|||
loss_dict = criterion[optimizer_idx](
|
||||
scores_disc_real,
|
||||
scores_disc_fake,
|
||||
scores_disc_zp,
|
||||
scores_disc_mp,
|
||||
end2end_info=end2end_info,
|
||||
)
|
||||
return outputs, loss_dict
|
||||
|
@ -1678,12 +1683,12 @@ class Vits(BaseTTS):
|
|||
)
|
||||
|
||||
# compute discriminator scores and features
|
||||
scores_disc_fake, feats_disc_fake, _, feats_disc_real = self.disc(
|
||||
self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"]
|
||||
scores_disc_fake, feats_disc_fake, _, feats_disc_real, scores_disc_mp, feats_disc_mp, _, feats_disc_zp = self.disc(
|
||||
self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"], self.model_outputs_cache["m_p"], self.model_outputs_cache["z_p"].detach()
|
||||
)
|
||||
|
||||
if self.args.use_end2end_loss:
|
||||
scores_disc_fake_end2end, feats_disc_fake_end2end, _, feats_disc_real_end2end = self.disc(
|
||||
scores_disc_fake_end2end, feats_disc_fake_end2end, _, feats_disc_real_end2end, _, _, _, _, _ = self.disc(
|
||||
self.model_outputs_cache["end2end_info"]["model_outputs"], self.model_outputs_cache["end2end_info"]["waveform_seg"]
|
||||
)
|
||||
self.model_outputs_cache["end2end_info"]["scores_disc_fake"] = scores_disc_fake_end2end
|
||||
|
@ -1713,6 +1718,9 @@ class Vits(BaseTTS):
|
|||
loss_prosody_enc_emo_classifier=self.model_outputs_cache["loss_prosody_enc_emo_classifier"],
|
||||
loss_text_enc_spk_rev_classifier=self.model_outputs_cache["loss_text_enc_spk_rev_classifier"],
|
||||
loss_text_enc_emo_classifier=self.model_outputs_cache["loss_text_enc_emo_classifier"],
|
||||
scores_disc_mp=scores_disc_mp,
|
||||
feats_disc_mp=feats_disc_mp,
|
||||
feats_disc_zp=feats_disc_zp,
|
||||
end2end_info=self.model_outputs_cache["end2end_info"],
|
||||
)
|
||||
|
||||
|
|
|
@ -12,8 +12,8 @@ output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
|||
|
||||
|
||||
config = VitsConfig(
|
||||
batch_size=2,
|
||||
eval_batch_size=2,
|
||||
batch_size=3,
|
||||
eval_batch_size=3,
|
||||
num_loader_workers=0,
|
||||
num_eval_loader_workers=0,
|
||||
text_cleaner="english_cleaners",
|
||||
|
@ -52,8 +52,9 @@ config.model_args.use_prosody_encoder_z_p_input = True
|
|||
config.model_args.prosody_encoder_type = "vae"
|
||||
config.model_args.detach_prosody_enc_input = True
|
||||
|
||||
config.model_args.use_latent_discriminator = True
|
||||
# enable end2end loss
|
||||
config.model_args.use_end2end_loss = True
|
||||
config.model_args.use_end2end_loss = False
|
||||
|
||||
config.mixed_precision = False
|
||||
|
||||
|
|
Loading…
Reference in New Issue