Add VAE prosody encoder

Edresson Casanova 2022-05-27 12:53:56 -03:00
parent 312789edbf
commit a2aecea8f3
5 changed files with 63 additions and 11 deletions

View File

@@ -117,6 +117,7 @@ class VitsConfig(BaseTTSConfig):
consistency_loss_alpha: float = 1.0
speaker_classifier_loss_alpha: float = 2.0
emotion_classifier_loss_alpha: float = 4.0
prosody_encoder_kl_loss_alpha: float = 5.0
# data loader params
return_wav: bool = True
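
The new weight is a regular VitsConfig field, so it can be overridden from the training config. A minimal sketch, assuming the standard TTS.tts.configs.vits_config module (the value shown is just the default introduced here):

from TTS.tts.configs.vits_config import VitsConfig

# Illustrative: raise or lower the KL weight to trade prosody-latent regularization
# against reconstruction quality; 5.0 is the default added by this commit.
config = VitsConfig(prosody_encoder_kl_loss_alpha=5.0)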

View File

@@ -592,6 +592,7 @@ class VitsGeneratorLoss(nn.Module):
self.consistency_loss_alpha = c.consistency_loss_alpha
self.emotion_classifier_alpha = c.emotion_classifier_loss_alpha
self.speaker_classifier_alpha = c.speaker_classifier_loss_alpha
self.prosody_encoder_kl_loss_alpha = c.prosody_encoder_kl_loss_alpha
self.stft = TorchSTFT(
c.audio.fft_size,
@@ -665,6 +666,7 @@ class VitsGeneratorLoss(nn.Module):
use_encoder_consistency_loss=False,
gt_cons_emb=None,
syn_cons_emb=None,
vae_outputs=None,
loss_prosody_enc_spk_rev_classifier=None,
loss_prosody_enc_emo_classifier=None,
loss_text_enc_spk_rev_classifier=None,
@@ -725,7 +727,16 @@ class VitsGeneratorLoss(nn.Module):
loss += loss_text_enc_emo_classifier
return_dict["loss_text_enc_emo_classifier"] = loss_text_enc_emo_classifier
if vae_outputs is not None:
posterior_distribution, prior_distribution = vae_outputs
# KL divergence term between the posterior and the prior
kl_term = torch.mean(torch.distributions.kl_divergence(posterior_distribution, prior_distribution))
# multiply the loss by the alpha
kl_vae_loss = kl_term * self.prosody_encoder_kl_loss_alpha
loss += kl_vae_loss
return_dict["loss_kl_vae"] = kl_vae_loss
# pass losses to the dict
return_dict["loss_gen"] = loss_gen
return_dict["loss_kl"] = loss_kl

View File

@@ -0,0 +1,19 @@
from TTS.tts.layers.tacotron.gst_layers import GST
from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE
class VitsGST(GST):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, inputs, input_lengths=None, speaker_embedding=None):
style_embed = super().forward(inputs, speaker_embedding=speaker_embedding)
return style_embed, None
class VitsVAE(CapacitronVAE):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.beta = None
def forward(self, inputs, input_lengths=None):
VAE_embedding, posterior_distribution, prior_distribution, _ = super().forward([inputs, input_lengths])
return VAE_embedding.to(inputs.device), [posterior_distribution, prior_distribution]
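
A minimal usage sketch mirroring how Vits calls the wrapper; the sizes below are assumptions, and the input layout simply follows the model's own call, which passes the posterior-encoder latent straight through:

import torch

# Hypothetical sizes: batch of 2, 192 latent channels, 100 frames.
encoder = VitsVAE(num_mel=192, capacitron_VAE_embedding_dim=64)
z = torch.randn(2, 192, 100)
z_lengths = torch.tensor([100, 80])

pros_emb, vae_outputs = encoder(z, z_lengths)
posterior_distribution, prior_distribution = vae_outputs
# pros_emb conditions the rest of the model; the two distributions feed the
# KL term added to VitsGeneratorLoss.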

View File

@@ -19,10 +19,11 @@ from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
from TTS.tts.layers.generic.classifier import ReversalClassifier
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.tacotron.gst_layers import GST
from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.emotions import EmotionManager
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
@@ -545,6 +546,7 @@ class VitsArgs(Coqpit):
# prosody encoder
use_prosody_encoder: bool = False
prosody_encoder_type: str = "gst"
prosody_embedding_dim: int = 0
prosody_encoder_num_heads: int = 1
prosody_encoder_num_tokens: int = 5
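
Choosing between the two encoders is therefore a config decision made through the fields above. An illustrative sketch of the relevant VitsArgs settings (the embedding size is an assumption, not a repo default):

from TTS.tts.models.vits import VitsArgs

model_args = VitsArgs(
    use_prosody_encoder=True,
    prosody_encoder_type="vae",  # "gst" keeps the Global Style Token encoder
    prosody_embedding_dim=64,    # latent size shared by both encoder types
)
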
@@ -698,11 +700,21 @@ class Vits(BaseTTS):
)
if self.args.use_prosody_encoder:
self.prosody_encoder = GST(
num_mel=self.args.hidden_channels,
num_heads=self.args.prosody_encoder_num_heads,
num_style_tokens=self.args.prosody_encoder_num_tokens,
gst_embedding_dim=self.args.prosody_embedding_dim,
if self.args.prosody_encoder_type == 'gst':
self.prosody_encoder = VitsGST(
num_mel=self.args.hidden_channels,
num_heads=self.args.prosody_encoder_num_heads,
num_style_tokens=self.args.prosody_encoder_num_tokens,
gst_embedding_dim=self.args.prosody_embedding_dim,
)
elif self.args.prosody_encoder_type == 'vae':
self.prosody_encoder = VitsVAE(
num_mel=self.args.hidden_channels,
capacitron_VAE_embedding_dim=self.args.prosody_embedding_dim,
)
else:
raise RuntimeError(
f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!"
)
if self.args.use_prosody_enc_spk_reversal_classifier:
self.speaker_reversal_classifier = ReversalClassifier(
@@ -1142,9 +1154,11 @@ class Vits(BaseTTS):
l_pros_emotion = None
if self.args.use_prosody_encoder:
if not self.args.use_prosody_encoder_z_p_input:
pros_emb = self.prosody_encoder(z).transpose(1, 2)
pros_emb, vae_outputs = self.prosody_encoder(z, y_lengths)
else:
pros_emb = self.prosody_encoder(z_p).transpose(1, 2)
pros_emb, vae_outputs = self.prosody_encoder(z_p, y_lengths)
pros_emb = pros_emb.transpose(1, 2)
if self.args.use_prosody_enc_spk_reversal_classifier:
_, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
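
Both branches of the prosody-encoder call now return an (embedding, vae_outputs) pair, so the transpose is applied once after the branch. A minimal illustration of the unpacking pattern with placeholder tensors (shapes are assumptions; the GST wrapper simply yields None for vae_outputs):

import torch

pros_emb, vae_outputs = torch.randn(2, 1, 64), None  # stand-in for the GST branch
pros_emb = pros_emb.transpose(1, 2)                   # [batch, 1, dim] -> [batch, dim, 1]
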
@@ -1253,6 +1267,7 @@ class Vits(BaseTTS):
"gt_cons_emb": gt_cons_emb,
"syn_cons_emb": syn_cons_emb,
"slice_ids": slice_ids,
"vae_outputs": vae_outputs,
"loss_prosody_enc_spk_rev_classifier": l_pros_speaker,
"loss_prosody_enc_emo_classifier": l_pros_emotion,
"loss_text_enc_spk_rev_classifier": l_text_speaker,
@@ -1322,10 +1337,12 @@ class Vits(BaseTTS):
pf_lengths = torch.tensor([pf.size(-1)]).to(pf.device)
z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=g)
if not self.args.use_prosody_encoder_z_p_input:
pros_emb = self.prosody_encoder(z_pro).transpose(1, 2)
pros_emb, vae_outputs = self.prosody_encoder(z_pro, pf_lengths)
else:
z_p_inf = self.flow(z_pro, z_pro_y_mask, g=g)
pros_emb = self.prosody_encoder(z_p_inf).transpose(1, 2)
pros_emb, vae_outputs = self.prosody_encoder(z_p_inf, pf_lengths)
pros_emb = pros_emb.transpose(1, 2)
x, m_p, logs_p, x_mask = self.text_encoder(
x,
@@ -1469,6 +1486,7 @@ class Vits(BaseTTS):
else:
raise RuntimeError(" [!] Voice conversion is only supported on multi-speaker models.")
# emotion embedding
ge_src, ge_tgt = None, None
if self.args.use_emotion_embedding and ref_emotion is not None and target_emotion is not None:
ge_src = self.emb_g(ref_emotion).unsqueeze(-1)
ge_tgt = self.emb_g(target_emotion).unsqueeze(-1)
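
Initializing ge_src and ge_tgt up front keeps the later conditioning code valid when emotion embeddings are disabled, since both names are otherwise only bound inside the conditional branch. A small illustration of the pattern (hypothetical function, not code from the repo):

def emotion_embeddings(emb_table, ref_emotion=None, target_emotion=None):
    ge_src, ge_tgt = None, None  # defaults when no emotion conditioning is requested
    if ref_emotion is not None and target_emotion is not None:
        ge_src = emb_table(ref_emotion).unsqueeze(-1)
        ge_tgt = emb_table(target_emotion).unsqueeze(-1)
    return ge_src, ge_tgt
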
@@ -1602,6 +1620,7 @@ class Vits(BaseTTS):
or self.args.use_emotion_encoder_as_loss,
gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
vae_outputs=self.model_outputs_cache["vae_outputs"],
loss_prosody_enc_spk_rev_classifier=self.model_outputs_cache["loss_prosody_enc_spk_rev_classifier"],
loss_prosody_enc_emo_classifier=self.model_outputs_cache["loss_prosody_enc_emo_classifier"],
loss_text_enc_spk_rev_classifier=self.model_outputs_cache["loss_text_enc_spk_rev_classifier"],

View File

@@ -49,6 +49,8 @@ config.model_args.use_prosody_enc_emo_classifier = False
config.model_args.use_text_enc_emo_classifier = True
config.model_args.use_prosody_encoder_z_p_input = True
config.model_args.prosody_encoder_type = "vae"
config.save_json(config_path)
# train the model for one epoch