mirror of https://github.com/coqui-ai/TTS.git

Condition the prosody encoder on z_p

parent 512525cc39
commit f774cf0648
@@ -548,6 +548,7 @@ class VitsArgs(Coqpit):
     prosody_embedding_dim: int = 0
     prosody_encoder_num_heads: int = 1
     prosody_encoder_num_tokens: int = 5
+    use_prosody_encoder_z_p_input: bool = False
     use_prosody_enc_spk_reversal_classifier: bool = False
     use_prosody_enc_emo_classifier: bool = False
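
The new field defaults to False, so existing configs keep the old behavior of feeding the posterior latent z to the prosody encoder. A minimal sketch of enabling it, reusing field names that appear elsewhere in this commit; the import path is assumed from upstream coqui-ai/TTS:

    # Sketch: turn on z_p conditioning (field names from this commit; the
    # import path is an assumption based on upstream coqui-ai/TTS).
    from TTS.tts.models.vits import VitsArgs

    model_args = VitsArgs(
        use_prosody_encoder=True,            # existing prosody-encoder switch
        prosody_embedding_dim=64,            # existing field, 0 means disabled
        use_prosody_encoder_z_p_input=True,  # new: feed z_p instead of z
    )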
@@ -1132,17 +1133,24 @@ class Vits(BaseTTS):
         # posterior encoder
         z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)

+        # flow layers
+        z_p = self.flow(z, y_mask, g=g)
+
         # prosody embedding
         pros_emb = None
         l_pros_speaker = None
         l_pros_emotion = None
         if self.args.use_prosody_encoder:
-            pros_emb = self.prosody_encoder(z).transpose(1, 2)
+            if not self.args.use_prosody_encoder_z_p_input:
+                pros_emb = self.prosody_encoder(z).transpose(1, 2)
+            else:
+                pros_emb = self.prosody_encoder(z_p).transpose(1, 2)

             if self.args.use_prosody_enc_spk_reversal_classifier:
                 _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
             if self.args.use_prosody_enc_emo_classifier:
                 _, l_pros_emotion = self.pros_enc_emotion_classifier(pros_emb.transpose(1, 2), eid, x_mask=None)

         x, m_p, logs_p, x_mask = self.text_encoder(
             x,
             x_lengths,
@@ -1156,9 +1164,6 @@ class Vits(BaseTTS):
         if self.args.use_text_enc_spk_reversal_classifier:
             _, l_text_speaker = self.speaker_text_enc_reversal_classifier(x.transpose(1, 2), sid, x_mask=None)

-        # flow layers
-        z_p = self.flow(z, y_mask, g=g)
-
         # reversal speaker loss to force the encoder to be speaker identity free
         l_text_emotion = None
         if self.args.use_text_enc_emo_classifier:
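
Together, these two hunks move the flow invocation ahead of the prosody block, so z_p already exists when the prosody encoder runs, and make the encoder input switchable. In VITS, z is the raw posterior latent while z_p = flow(z) is that latent mapped toward the prior space, so the flag chooses between conditioning prosody on raw or flow-normalized features. A toy restatement of the branch, with the [B, C, T] in / [B, T, C] out shape convention as an assumption:

    # Toy restatement of the training-time branch (tensor names from the diff;
    # the encoder is passed in to keep the sketch self-contained).
    def prosody_embedding(prosody_encoder, z, z_p, use_z_p_input: bool):
        pros_in = z_p if use_z_p_input else z            # assumed [B, C, T]
        return prosody_encoder(pros_in).transpose(1, 2)  # assumed [B, T, C]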
@@ -1315,8 +1320,12 @@ class Vits(BaseTTS):
         if self.args.use_prosody_encoder:
             # extract posterior encoder feature
             pf_lengths = torch.tensor([pf.size(-1)]).to(pf.device)
-            z_pro, _, _, _ = self.posterior_encoder(pf, pf_lengths, g=g)
-            pros_emb = self.prosody_encoder(z_pro).transpose(1, 2)
+            z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=g)
+            if not self.args.use_prosody_encoder_z_p_input:
+                pros_emb = self.prosody_encoder(z_pro).transpose(1, 2)
+            else:
+                z_p_inf = self.flow(z_pro, z_pro_y_mask, g=g)
+                pros_emb = self.prosody_encoder(z_p_inf).transpose(1, 2)

         x, m_p, logs_p, x_mask = self.text_encoder(
             x,
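
At inference time the training z_p is not available, so when the flag is set the reference latents are pushed through the same flow (with the posterior mask returned alongside them) before the prosody encoder sees them; that is why this hunk starts capturing z_pro_y_mask instead of discarding it. A condensed, self-contained restatement under the same assumptions:

    import torch

    # Condensed restatement of the inference path (module and tensor names
    # from the diff; modules are passed in so the sketch stands alone).
    def infer_prosody_embedding(posterior_encoder, flow, prosody_encoder, pf, g, use_z_p_input: bool):
        pf_lengths = torch.tensor([pf.size(-1)]).to(pf.device)
        z_pro, _, _, z_pro_y_mask = posterior_encoder(pf, pf_lengths, g=g)
        if use_z_p_input:
            z_pro = flow(z_pro, z_pro_y_mask, g=g)  # map reference latents into the z_p space
        return prosody_encoder(z_pro).transpose(1, 2)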
@@ -94,7 +94,7 @@ class EmotionManager(EmbeddingManager):
        EmotionEncoder: Emotion encoder object.
    """
    emotion_manager = None
-    if get_from_config_or_model_args_with_default(config, "use_emotion_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False):
+    if get_from_config_or_model_args_with_default(config, "use_emotion_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False) or get_from_config_or_model_args_with_default(config, "use_text_enc_emo_classifier", False):
        if get_from_config_or_model_args_with_default(config, "emotions_ids_file", None):
            emotion_manager = EmotionManager(
                emotion_id_file_path=get_from_config_or_model_args_with_default(config, "emotions_ids_file", None)

@@ -106,7 +106,7 @@ class EmotionManager(EmbeddingManager):
            )
        )

-    if get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False):
+    if get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False) or get_from_config_or_model_args_with_default(config, "use_text_enc_emo_classifier", False):
        if get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None):
            emotion_manager = EmotionManager(
                embeddings_file_path=get_from_config_or_model_args_with_default(
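
Both guards gain the same third clause, use_text_enc_emo_classifier, so an EmotionManager is also created when only the text-encoder emotion classifier is active. Not part of the commit, but the repeated pattern could be collapsed into a small helper; the import path below is an assumption:

    # Hypothetical helper (not in the commit): one place for the three-flag
    # check that both branches above repeat. Import path assumed.
    from TTS.config import get_from_config_or_model_args_with_default

    def wants_emotion_manager(config, primary_flag: str) -> bool:
        flags = (primary_flag, "use_prosody_enc_emo_classifier", "use_text_enc_emo_classifier")
        return any(get_from_config_or_model_args_with_default(config, f, False) for f in flags)

    # Usage mirroring the two branches above:
    #   wants_emotion_manager(config, "use_emotion_embedding")
    #   wants_emotion_manager(config, "use_external_emotions_embeddings")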
@@ -45,8 +45,9 @@ config.model_args.use_prosody_encoder = True
 config.model_args.prosody_embedding_dim = 64
 # active classifier
 config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
-config.model_args.use_prosody_enc_emo_classifier = True
+config.model_args.use_prosody_enc_emo_classifier = False
+config.model_args.use_text_enc_emo_classifier = True
+config.model_args.use_prosody_encoder_z_p_input = True

 config.save_json(config_path)
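
With these flag flips the test config now exercises the text-encoder emotion classifier together with the z_p-conditioned prosody encoder, instead of the prosody-encoder emotion classifier it toggled before; the external embeddings file it points at is the speakers.json fixture reused as emotion embeddings.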