mirror of https://github.com/coqui-ai/TTS.git

Add VAE prosody encoder

commit a2aecea8f3
parent 312789edbf
TTS/tts/configs/vits_config.py

@@ -117,6 +117,7 @@ class VitsConfig(BaseTTSConfig):
     consistency_loss_alpha: float = 1.0
     speaker_classifier_loss_alpha: float = 2.0
     emotion_classifier_loss_alpha: float = 4.0
+    prosody_encoder_kl_loss_alpha: float = 5.0
 
     # data loader params
     return_wav: bool = True
TTS/tts/layers/losses.py

@@ -592,6 +592,7 @@ class VitsGeneratorLoss(nn.Module):
         self.consistency_loss_alpha = c.consistency_loss_alpha
         self.emotion_classifier_alpha = c.emotion_classifier_loss_alpha
         self.speaker_classifier_alpha = c.speaker_classifier_loss_alpha
+        self.prosody_encoder_kl_loss_alpha = c.prosody_encoder_kl_loss_alpha
 
         self.stft = TorchSTFT(
             c.audio.fft_size,
@@ -665,6 +666,7 @@ class VitsGeneratorLoss(nn.Module):
         use_encoder_consistency_loss=False,
         gt_cons_emb=None,
         syn_cons_emb=None,
+        vae_outputs=None,
         loss_prosody_enc_spk_rev_classifier=None,
         loss_prosody_enc_emo_classifier=None,
         loss_text_enc_spk_rev_classifier=None,
@@ -725,7 +727,16 @@ class VitsGeneratorLoss(nn.Module):
             loss += loss_text_enc_emo_classifier
             return_dict["loss_text_enc_emo_classifier"] = loss_text_enc_emo_classifier
 
+        if vae_outputs is not None:
+            posterior_distribution, prior_distribution = vae_outputs
+            # KL divergence term between the posterior and the prior
+            kl_term = torch.mean(torch.distributions.kl_divergence(posterior_distribution, prior_distribution))
+            # multiply the loss by the alpha
+            kl_vae_loss = kl_term * self.prosody_encoder_kl_loss_alpha
+
+            loss += kl_vae_loss
+            return_dict["loss_kl_vae"] = kl_vae_loss
 
         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
         return_dict["loss_kl"] = loss_kl
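
For reference, the KL term above can be reproduced in isolation. A minimal sketch, assuming the posterior and prior arrive as torch.distributions.Normal instances (shapes are illustrative, not from the commit):

    import torch
    from torch.distributions import Normal, kl_divergence

    # Illustrative shapes: batch of 2, 128-dim prosody latent.
    posterior_distribution = Normal(torch.randn(2, 128), 0.5 * torch.ones(2, 128))
    prior_distribution = Normal(torch.zeros(2, 128), torch.ones(2, 128))

    # Element-wise KL reduced to a scalar, then scaled, mirroring the loss above.
    kl_term = torch.mean(kl_divergence(posterior_distribution, prior_distribution))
    kl_vae_loss = kl_term * 5.0  # prosody_encoder_kl_loss_alpha default from the config hunk
    print(kl_vae_loss.item())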
TTS/tts/layers/vits/prosody_encoder.py (new file)

@@ -0,0 +1,19 @@
+from TTS.tts.layers.tacotron.gst_layers import GST
+from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE
+
+class VitsGST(GST):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, inputs, input_lengths=None, speaker_embedding=None):
+        style_embed = super().forward(inputs, speaker_embedding=speaker_embedding)
+        return style_embed, None
+
+class VitsVAE(CapacitronVAE):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.beta = None
+
+    def forward(self, inputs, input_lengths=None):
+        VAE_embedding, posterior_distribution, prior_distribution, _ = super().forward([inputs, input_lengths])
+        return VAE_embedding.to(inputs.device), [posterior_distribution, prior_distribution]
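
Both wrappers normalize the return contract: the caller always gets back (embedding, distributions), where distributions is None for the GST path and the [posterior, prior] pair for the VAE path. A toy sketch of that contract; ToyGSTLike and ToyVAELike are illustrative stand-ins, not part of the commit:

    import torch
    import torch.nn as nn
    from torch.distributions import Normal

    class ToyGSTLike(nn.Module):
        # Stand-in: returns an embedding and no distributions, like VitsGST.
        def forward(self, inputs, input_lengths=None):
            return inputs.mean(dim=2, keepdim=True).transpose(1, 2), None

    class ToyVAELike(nn.Module):
        # Stand-in: returns an embedding plus [posterior, prior], like VitsVAE.
        def forward(self, inputs, input_lengths=None):
            emb = inputs.mean(dim=2).unsqueeze(1)
            posterior = Normal(emb, torch.ones_like(emb))
            prior = Normal(torch.zeros_like(emb), torch.ones_like(emb))
            return emb, [posterior, prior]

    z = torch.randn(2, 192, 50)  # [batch, hidden_channels, frames]
    for encoder in (ToyGSTLike(), ToyVAELike()):
        pros_emb, vae_outputs = encoder(z)
        # The caller only computes a KL loss when distributions are returned.
        print(type(encoder).__name__, pros_emb.shape, vae_outputs is not None)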
TTS/tts/models/vits.py

@@ -19,10 +19,11 @@ from TTS.tts.configs.shared_configs import CharactersConfig
 from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
 from TTS.tts.layers.generic.classifier import ReversalClassifier
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
-from TTS.tts.layers.tacotron.gst_layers import GST
+from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
+
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.emotions import EmotionManager
 from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
@@ -545,6 +546,7 @@ class VitsArgs(Coqpit):
 
     # prosody encoder
     use_prosody_encoder: bool = False
+    prosody_encoder_type: str = "gst"
     prosody_embedding_dim: int = 0
     prosody_encoder_num_heads: int = 1
     prosody_encoder_num_tokens: int = 5
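
Enabling the new encoder is then a config-level switch. A minimal sketch using the fields from this diff (the embedding dim value is illustrative):

    from TTS.tts.configs.vits_config import VitsConfig

    config = VitsConfig()
    config.model_args.use_prosody_encoder = True
    config.model_args.prosody_encoder_type = "vae"  # or "gst" (the default)
    config.model_args.prosody_embedding_dim = 64  # illustrative value
    config.prosody_encoder_kl_loss_alpha = 5.0  # weight of the new KL term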
@@ -698,11 +700,21 @@ class Vits(BaseTTS):
         )
 
         if self.args.use_prosody_encoder:
-            self.prosody_encoder = GST(
-                num_mel=self.args.hidden_channels,
-                num_heads=self.args.prosody_encoder_num_heads,
-                num_style_tokens=self.args.prosody_encoder_num_tokens,
-                gst_embedding_dim=self.args.prosody_embedding_dim,
-            )
+            if self.args.prosody_encoder_type == 'gst':
+                self.prosody_encoder = VitsGST(
+                    num_mel=self.args.hidden_channels,
+                    num_heads=self.args.prosody_encoder_num_heads,
+                    num_style_tokens=self.args.prosody_encoder_num_tokens,
+                    gst_embedding_dim=self.args.prosody_embedding_dim,
+                )
+            elif self.args.prosody_encoder_type == 'vae':
+                self.prosody_encoder = VitsVAE(
+                    num_mel=self.args.hidden_channels,
+                    capacitron_VAE_embedding_dim=self.args.prosody_embedding_dim,
+                )
+            else:
+                raise RuntimeError(
+                    f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!"
+                )
         if self.args.use_prosody_enc_spk_reversal_classifier:
             self.speaker_reversal_classifier = ReversalClassifier(
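
Note that the VAE branch passes only num_mel and the embedding dimension, so the remaining CapacitronVAE constructor arguments are presumably left at their library defaults; any other prosody_encoder_type value fails fast with the RuntimeError above.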
@@ -1142,9 +1154,11 @@ class Vits(BaseTTS):
         l_pros_emotion = None
         if self.args.use_prosody_encoder:
             if not self.args.use_prosody_encoder_z_p_input:
-                pros_emb = self.prosody_encoder(z).transpose(1, 2)
+                pros_emb, vae_outputs = self.prosody_encoder(z, y_lengths)
             else:
-                pros_emb = self.prosody_encoder(z_p).transpose(1, 2)
+                pros_emb, vae_outputs = self.prosody_encoder(z_p, y_lengths)
+
+            pros_emb = pros_emb.transpose(1, 2)
 
             if self.args.use_prosody_enc_spk_reversal_classifier:
                 _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
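
With both encoder types returning the raw embedding, the transpose into VITS's channel-first layout now happens once after the branch. A shape sketch, assuming a GST-style [batch, 1, dim] embedding:

    import torch

    pros_emb = torch.randn(2, 1, 64)     # [batch, tokens, prosody_embedding_dim]
    pros_emb = pros_emb.transpose(1, 2)  # -> [batch, prosody_embedding_dim, 1]
    print(pros_emb.shape)                # torch.Size([2, 64, 1])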
@@ -1253,6 +1267,7 @@ class Vits(BaseTTS):
             "gt_cons_emb": gt_cons_emb,
             "syn_cons_emb": syn_cons_emb,
             "slice_ids": slice_ids,
+            "vae_outputs": vae_outputs,
             "loss_prosody_enc_spk_rev_classifier": l_pros_speaker,
             "loss_prosody_enc_emo_classifier": l_pros_emotion,
             "loss_text_enc_spk_rev_classifier": l_text_speaker,
@@ -1322,10 +1337,12 @@ class Vits(BaseTTS):
             pf_lengths = torch.tensor([pf.size(-1)]).to(pf.device)
             z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=g)
             if not self.args.use_prosody_encoder_z_p_input:
-                pros_emb = self.prosody_encoder(z_pro).transpose(1, 2)
+                pros_emb, vae_outputs = self.prosody_encoder(z_pro, pf_lengths)
             else:
                 z_p_inf = self.flow(z_pro, z_pro_y_mask, g=g)
-                pros_emb = self.prosody_encoder(z_p_inf).transpose(1, 2)
+                pros_emb, vae_outputs = self.prosody_encoder(z_p_inf, pf_lengths)
+
+            pros_emb = pros_emb.transpose(1, 2)
 
         x, m_p, logs_p, x_mask = self.text_encoder(
             x,
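
The inference path mirrors the training-time change; the returned vae_outputs is captured here but goes unused, since the KL term only matters during training.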
@@ -1469,6 +1486,7 @@ class Vits(BaseTTS):
         else:
             raise RuntimeError(" [!] Voice conversion is only supported on multi-speaker models.")
         # emotion embedding
+        ge_src, ge_tgt = None, None
         if self.args.use_emotion_embedding and ref_emotion is not None and target_emotion is not None:
             ge_src = self.emb_g(ref_emotion).unsqueeze(-1)
             ge_tgt = self.emb_g(target_emotion).unsqueeze(-1)
@@ -1602,6 +1620,7 @@ class Vits(BaseTTS):
                     or self.args.use_emotion_encoder_as_loss,
                     gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
                     syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
+                    vae_outputs=self.model_outputs_cache["vae_outputs"],
                     loss_prosody_enc_spk_rev_classifier=self.model_outputs_cache["loss_prosody_enc_spk_rev_classifier"],
                     loss_prosody_enc_emo_classifier=self.model_outputs_cache["loss_prosody_enc_emo_classifier"],
                     loss_text_enc_spk_rev_classifier=self.model_outputs_cache["loss_text_enc_spk_rev_classifier"],
tests/… (VITS training test script)

@@ -49,6 +49,8 @@ config.model_args.use_prosody_enc_emo_classifier = False
 config.model_args.use_text_enc_emo_classifier = True
 config.model_args.use_prosody_encoder_z_p_input = True
 
+config.model_args.prosody_encoder_type = "vae"
+
 config.save_json(config_path)
 
 # train the model for one epoch
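
This flips the training test to the VAE encoder, so the one-epoch run exercises the new prosody_encoder_kl_loss_alpha-weighted KL path end to end.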