Condition the prosody encoder on z_p

Edresson Casanova 2022-05-26 15:41:24 -03:00
parent 512525cc39
commit f774cf0648
3 changed files with 20 additions and 10 deletions
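
In the VITS forward pass, z is the latent drawn from the posterior encoder and z_p is its flow-transformed counterpart. This commit moves the flow computation ahead of the prosody encoder and adds a use_prosody_encoder_z_p_input flag that, when set, feeds z_p rather than z into it. A minimal toy sketch of that input selection (the GRU and the random tensors are placeholders, not the real modules; only the flag and tensor names come from the diff below):

import torch

# Toy illustration, not the Coqui TTS implementation: the new flag routes either
# the posterior latent z or the flow-transformed latent z_p into the prosody encoder.
use_prosody_encoder_z_p_input = True  # mirrors VitsArgs.use_prosody_encoder_z_p_input

B, C, T = 2, 192, 100                  # batch, latent channels, frames (assumed sizes)
z = torch.randn(B, C, T)               # stands in for the posterior_encoder output
z_p = torch.randn(B, C, T)             # stands in for flow(z, y_mask, g=g)
prosody_encoder = torch.nn.GRU(C, 64, batch_first=True)  # placeholder prosody encoder

prosody_input = z_p if use_prosody_encoder_z_p_input else z
pros_emb, _ = prosody_encoder(prosody_input.transpose(1, 2))  # (B, T, 64)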


@@ -548,6 +548,7 @@ class VitsArgs(Coqpit):
     prosody_embedding_dim: int = 0
     prosody_encoder_num_heads: int = 1
     prosody_encoder_num_tokens: int = 5
+    use_prosody_encoder_z_p_input: bool = False
     use_prosody_enc_spk_reversal_classifier: bool = False
     use_prosody_enc_emo_classifier: bool = False
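
With the new VitsArgs field declared above, the behaviour can be switched on from a training config. A hedged example, assuming the stock Coqui TTS VitsConfig (whose model_args holds a VitsArgs), mirroring the test config at the end of this diff:

from TTS.tts.configs.vits_config import VitsConfig  # standard Coqui TTS import (assumed unchanged in this branch)

config = VitsConfig()
config.model_args.use_prosody_encoder = True            # the prosody encoder itself must be enabled
config.model_args.prosody_embedding_dim = 64
config.model_args.use_prosody_encoder_z_p_input = True  # new: condition the prosody encoder on z_p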
@@ -1132,17 +1133,24 @@ class Vits(BaseTTS):
         # posterior encoder
         z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
+        # flow layers
+        z_p = self.flow(z, y_mask, g=g)
         # prosody embedding
         pros_emb = None
         l_pros_speaker = None
         l_pros_emotion = None
         if self.args.use_prosody_encoder:
-            pros_emb = self.prosody_encoder(z).transpose(1, 2)
+            if not self.args.use_prosody_encoder_z_p_input:
+                pros_emb = self.prosody_encoder(z).transpose(1, 2)
+            else:
+                pros_emb = self.prosody_encoder(z_p).transpose(1, 2)
             if self.args.use_prosody_enc_spk_reversal_classifier:
                 _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
             if self.args.use_prosody_enc_emo_classifier:
                 _, l_pros_emotion = self.pros_enc_emotion_classifier(pros_emb.transpose(1, 2), eid, x_mask=None)
         x, m_p, logs_p, x_mask = self.text_encoder(
             x,
             x_lengths,
@@ -1156,9 +1164,6 @@ class Vits(BaseTTS):
         if self.args.use_text_enc_spk_reversal_classifier:
             _, l_text_speaker = self.speaker_text_enc_reversal_classifier(x.transpose(1, 2), sid, x_mask=None)
-        # flow layers
-        z_p = self.flow(z, y_mask, g=g)
         # reversal speaker loss to force the encoder to be speaker identity free
         l_text_emotion = None
         if self.args.use_text_enc_emo_classifier:
@@ -1315,8 +1320,12 @@ class Vits(BaseTTS):
         if self.args.use_prosody_encoder:
             # extract posterior encoder feature
             pf_lengths = torch.tensor([pf.size(-1)]).to(pf.device)
-            z_pro, _, _, _ = self.posterior_encoder(pf, pf_lengths, g=g)
-            pros_emb = self.prosody_encoder(z_pro).transpose(1, 2)
+            z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=g)
+            if not self.args.use_prosody_encoder_z_p_input:
+                pros_emb = self.prosody_encoder(z_pro).transpose(1, 2)
+            else:
+                z_p_inf = self.flow(z_pro, z_pro_y_mask, g=g)
+                pros_emb = self.prosody_encoder(z_p_inf).transpose(1, 2)
         x, m_p, logs_p, x_mask = self.text_encoder(
             x,


@@ -94,7 +94,7 @@ class EmotionManager(EmbeddingManager):
             EmotionEncoder: Emotion encoder object.
         """
         emotion_manager = None
-        if get_from_config_or_model_args_with_default(config, "use_emotion_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False):
+        if get_from_config_or_model_args_with_default(config, "use_emotion_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False) or get_from_config_or_model_args_with_default(config, "use_text_enc_emo_classifier", False):
             if get_from_config_or_model_args_with_default(config, "emotions_ids_file", None):
                 emotion_manager = EmotionManager(
                     emotion_id_file_path=get_from_config_or_model_args_with_default(config, "emotions_ids_file", None)
@@ -106,7 +106,7 @@ class EmotionManager(EmbeddingManager):
                 )
             )
-        if get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False):
+        if get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False) or get_from_config_or_model_args_with_default(config, "use_text_enc_emo_classifier", False):
             if get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None):
                 emotion_manager = EmotionManager(
                     embeddings_file_path=get_from_config_or_model_args_with_default(
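
Both checks above gain the same third clause, so the emotion manager is now also built when only the text-encoder emotion classifier is active. The first check, restated compactly for illustration (only the helper and flag names from the diff are used; the file itself keeps the explicit or-chain):

# Illustrative restatement: any of these flags now triggers EmotionManager creation.
needs_emotion_manager = any(
    get_from_config_or_model_args_with_default(config, flag, False)
    for flag in ("use_emotion_embedding", "use_prosody_enc_emo_classifier", "use_text_enc_emo_classifier")
)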


@@ -45,8 +45,9 @@ config.model_args.use_prosody_encoder = True
 config.model_args.prosody_embedding_dim = 64
 # active classifier
 config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
-config.model_args.use_prosody_enc_emo_classifier = True
+config.model_args.use_prosody_enc_emo_classifier = False
+config.model_args.use_text_enc_emo_classifier = True
+config.model_args.use_prosody_encoder_z_p_input = True
 config.save_json(config_path)