mirror of https://github.com/coqui-ai/TTS.git
Add VAE prosody encoder
This commit is contained in:
parent 312789edbf
commit a2aecea8f3
@@ -117,6 +117,7 @@ class VitsConfig(BaseTTSConfig):
     consistency_loss_alpha: float = 1.0
     speaker_classifier_loss_alpha: float = 2.0
     emotion_classifier_loss_alpha: float = 4.0
+    prosody_encoder_kl_loss_alpha: float = 5.0

     # data loader params
     return_wav: bool = True

@@ -592,6 +592,7 @@ class VitsGeneratorLoss(nn.Module):
         self.consistency_loss_alpha = c.consistency_loss_alpha
         self.emotion_classifier_alpha = c.emotion_classifier_loss_alpha
         self.speaker_classifier_alpha = c.speaker_classifier_loss_alpha
+        self.prosody_encoder_kl_loss_alpha = c.prosody_encoder_kl_loss_alpha

         self.stft = TorchSTFT(
             c.audio.fft_size,

@@ -665,6 +666,7 @@ class VitsGeneratorLoss(nn.Module):
         use_encoder_consistency_loss=False,
         gt_cons_emb=None,
         syn_cons_emb=None,
+        vae_outputs=None,
         loss_prosody_enc_spk_rev_classifier=None,
         loss_prosody_enc_emo_classifier=None,
         loss_text_enc_spk_rev_classifier=None,

@@ -725,7 +727,16 @@ class VitsGeneratorLoss(nn.Module):
             loss += loss_text_enc_emo_classifier
             return_dict["loss_text_enc_emo_classifier"] = loss_text_enc_emo_classifier

+        if vae_outputs is not None:
+            posterior_distribution, prior_distribution = vae_outputs
+            # KL divergence term between the posterior and the prior
+            kl_term = torch.mean(torch.distributions.kl_divergence(posterior_distribution, prior_distribution))
+            # multiply the loss by the alpha
+            kl_vae_loss = kl_term * self.prosody_encoder_kl_loss_alpha
+
+            loss += kl_vae_loss
+            return_dict["loss_kl_vae"] = kl_vae_loss

         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
         return_dict["loss_kl"] = loss_kl

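For orientation, a minimal self-contained sketch of what the new KL term computes, assuming the posterior and prior are torch.distributions objects as produced by the Capacitron-style VAE; the batch size, embedding dimension, and random tensors below are illustrative, not values from the commit:

import torch
from torch.distributions.multivariate_normal import MultivariateNormal

# Illustrative sizes: a batch of 2 utterances and a 64-dimensional prosody embedding.
batch_size, embedding_dim = 2, 64
posterior_distribution = MultivariateNormal(
    torch.randn(batch_size, embedding_dim), scale_tril=torch.eye(embedding_dim)
)
prior_distribution = MultivariateNormal(
    torch.zeros(batch_size, embedding_dim), scale_tril=torch.eye(embedding_dim)
)

# Same computation as the added loss lines: mean KL over the batch, scaled by the alpha.
prosody_encoder_kl_loss_alpha = 5.0  # default added to VitsConfig in this commit
kl_term = torch.mean(torch.distributions.kl_divergence(posterior_distribution, prior_distribution))
kl_vae_loss = kl_term * prosody_encoder_kl_loss_alpha
print(kl_vae_loss.item())
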
New file: TTS/tts/layers/vits/prosody_encoder.py

@@ -0,0 +1,19 @@
+from TTS.tts.layers.tacotron.gst_layers import GST
+from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE
+
+class VitsGST(GST):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, inputs, input_lengths=None, speaker_embedding=None):
+        style_embed = super().forward(inputs, speaker_embedding=speaker_embedding)
+        return style_embed, None
+
+class VitsVAE(CapacitronVAE):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.beta = None
+
+    def forward(self, inputs, input_lengths=None):
+        VAE_embedding, posterior_distribution, prior_distribution, _ = super().forward([inputs, input_lengths])
+        return VAE_embedding.to(inputs.device), [posterior_distribution, prior_distribution]

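Both wrappers give the model one call contract: forward() returns a (prosody embedding, extras) tuple, where extras is None for the GST path and the [posterior, prior] distribution pair for the VAE path, which is what feeds the new KL loss. A hedged construction sketch using the same keyword arguments the model passes below; this only works on this commit's branch, and the dimension values are illustrative:

# Assumes this branch, where TTS.tts.layers.vits.prosody_encoder exists.
from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE

hidden_channels = 192       # illustrative; the model passes args.hidden_channels
prosody_embedding_dim = 64  # illustrative; the model passes args.prosody_embedding_dim

# Token-based (GST) prosody encoder; forward() returns (style_embedding, None).
gst_encoder = VitsGST(
    num_mel=hidden_channels,
    num_heads=1,
    num_style_tokens=5,
    gst_embedding_dim=prosody_embedding_dim,
)

# Capacitron-based VAE prosody encoder; forward() returns
# (VAE_embedding, [posterior_distribution, prior_distribution]).
vae_encoder = VitsVAE(
    num_mel=hidden_channels,
    capacitron_VAE_embedding_dim=prosody_embedding_dim,
)
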
@ -19,10 +19,11 @@ from TTS.tts.configs.shared_configs import CharactersConfig
|
||||||
from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
|
from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
|
||||||
from TTS.tts.layers.generic.classifier import ReversalClassifier
|
from TTS.tts.layers.generic.classifier import ReversalClassifier
|
||||||
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
||||||
from TTS.tts.layers.tacotron.gst_layers import GST
|
from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE
|
||||||
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
|
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
|
||||||
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
|
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
|
||||||
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
|
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
|
||||||
|
|
||||||
from TTS.tts.models.base_tts import BaseTTS
|
from TTS.tts.models.base_tts import BaseTTS
|
||||||
from TTS.tts.utils.emotions import EmotionManager
|
from TTS.tts.utils.emotions import EmotionManager
|
||||||
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
|
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
|
||||||
|
@@ -545,6 +546,7 @@ class VitsArgs(Coqpit):

     # prosody encoder
     use_prosody_encoder: bool = False
+    prosody_encoder_type: str = "gst"
     prosody_embedding_dim: int = 0
     prosody_encoder_num_heads: int = 1
     prosody_encoder_num_tokens: int = 5

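Taken together with the new VitsConfig field, selecting the VAE prosody encoder is a config change. A minimal sketch, assuming this branch; the embedding dimension and alpha below are illustrative values, not recommendations:

from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import VitsArgs

model_args = VitsArgs(
    use_prosody_encoder=True,
    prosody_encoder_type="vae",  # "gst" keeps the token-based encoder (the default)
    prosody_embedding_dim=64,    # illustrative; the field defaults to 0
)
config = VitsConfig(
    model_args=model_args,
    prosody_encoder_kl_loss_alpha=5.0,  # weight of the new VAE KL term
)
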
@@ -698,11 +700,21 @@ class Vits(BaseTTS):
             )

         if self.args.use_prosody_encoder:
-            self.prosody_encoder = GST(
-                num_mel=self.args.hidden_channels,
-                num_heads=self.args.prosody_encoder_num_heads,
-                num_style_tokens=self.args.prosody_encoder_num_tokens,
-                gst_embedding_dim=self.args.prosody_embedding_dim,
-            )
+            if self.args.prosody_encoder_type == 'gst':
+                self.prosody_encoder = VitsGST(
+                    num_mel=self.args.hidden_channels,
+                    num_heads=self.args.prosody_encoder_num_heads,
+                    num_style_tokens=self.args.prosody_encoder_num_tokens,
+                    gst_embedding_dim=self.args.prosody_embedding_dim,
+                )
+            elif self.args.prosody_encoder_type == 'vae':
+                self.prosody_encoder = VitsVAE(
+                    num_mel=self.args.hidden_channels,
+                    capacitron_VAE_embedding_dim=self.args.prosody_embedding_dim,
+                )
+            else:
+                raise RuntimeError(
+                    f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!"
+                )
             if self.args.use_prosody_enc_spk_reversal_classifier:
                 self.speaker_reversal_classifier = ReversalClassifier(

@@ -1142,9 +1154,11 @@ class Vits(BaseTTS):
         l_pros_emotion = None
         if self.args.use_prosody_encoder:
             if not self.args.use_prosody_encoder_z_p_input:
-                pros_emb = self.prosody_encoder(z).transpose(1, 2)
+                pros_emb, vae_outputs = self.prosody_encoder(z, y_lengths)
             else:
-                pros_emb = self.prosody_encoder(z_p).transpose(1, 2)
+                pros_emb, vae_outputs = self.prosody_encoder(z_p, y_lengths)
+
+            pros_emb = pros_emb.transpose(1, 2)

             if self.args.use_prosody_enc_spk_reversal_classifier:
                 _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)

@@ -1253,6 +1267,7 @@ class Vits(BaseTTS):
                 "gt_cons_emb": gt_cons_emb,
                 "syn_cons_emb": syn_cons_emb,
                 "slice_ids": slice_ids,
+                "vae_outputs": vae_outputs,
                 "loss_prosody_enc_spk_rev_classifier": l_pros_speaker,
                 "loss_prosody_enc_emo_classifier": l_pros_emotion,
                 "loss_text_enc_spk_rev_classifier": l_text_speaker,

@@ -1322,10 +1337,12 @@ class Vits(BaseTTS):
             pf_lengths = torch.tensor([pf.size(-1)]).to(pf.device)
             z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=g)
             if not self.args.use_prosody_encoder_z_p_input:
-                pros_emb = self.prosody_encoder(z_pro).transpose(1, 2)
+                pros_emb, vae_outputs = self.prosody_encoder(z_pro, pf_lengths)
             else:
                 z_p_inf = self.flow(z_pro, z_pro_y_mask, g=g)
-                pros_emb = self.prosody_encoder(z_p_inf).transpose(1, 2)
+                pros_emb, vae_outputs = self.prosody_encoder(z_p_inf, pf_lengths)
+
+            pros_emb = pros_emb.transpose(1, 2)

         x, m_p, logs_p, x_mask = self.text_encoder(
             x,

@@ -1469,6 +1486,7 @@ class Vits(BaseTTS):
         else:
             raise RuntimeError(" [!] Voice conversion is only supported on multi-speaker models.")
         # emotion embedding
+        ge_src, ge_tgt = None, None
         if self.args.use_emotion_embedding and ref_emotion is not None and target_emotion is not None:
             ge_src = self.emb_g(ref_emotion).unsqueeze(-1)
             ge_tgt = self.emb_g(target_emotion).unsqueeze(-1)

@@ -1602,6 +1620,7 @@ class Vits(BaseTTS):
                 or self.args.use_emotion_encoder_as_loss,
                 gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
                 syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
+                vae_outputs=self.model_outputs_cache["vae_outputs"],
                 loss_prosody_enc_spk_rev_classifier=self.model_outputs_cache["loss_prosody_enc_spk_rev_classifier"],
                 loss_prosody_enc_emo_classifier=self.model_outputs_cache["loss_prosody_enc_emo_classifier"],
                 loss_text_enc_spk_rev_classifier=self.model_outputs_cache["loss_text_enc_spk_rev_classifier"],

@@ -49,6 +49,8 @@ config.model_args.use_prosody_enc_emo_classifier = False
 config.model_args.use_text_enc_emo_classifier = True
 config.model_args.use_prosody_encoder_z_p_input = True

+config.model_args.prosody_encoder_type = "vae"
+
 config.save_json(config_path)

 # train the model for one epoch