Add an option to detach the prosody encoder input

This commit is contained in:
Edresson Casanova 2022-05-27 13:19:06 -03:00
parent a2aecea8f3
commit 2568b722dd
2 changed files with 10 additions and 6 deletions

View File

@ -547,12 +547,14 @@ class VitsArgs(Coqpit):
# prosody encoder
use_prosody_encoder: bool = False
prosody_encoder_type: str = "gst"
detach_prosody_enc_input: bool = False
prosody_embedding_dim: int = 0
prosody_encoder_num_heads: int = 1
prosody_encoder_num_tokens: int = 5
use_prosody_encoder_z_p_input: bool = False
use_prosody_enc_spk_reversal_classifier: bool = False
use_prosody_enc_emo_classifier: bool = False
use_prosody_conditional_flow_module: bool = False
prosody_conditional_flow_module_on_decoder: bool = False
@ -1153,10 +1155,11 @@ class Vits(BaseTTS):
l_pros_speaker = None
l_pros_emotion = None
if self.args.use_prosody_encoder:
if not self.args.use_prosody_encoder_z_p_input:
pros_emb, vae_outputs = self.prosody_encoder(z, y_lengths)
else:
pros_emb, vae_outputs = self.prosody_encoder(z_p, y_lengths)
prosody_encoder_input = z_p if self.args.use_prosody_encoder_z_p_input else z
pros_emb, vae_outputs = self.prosody_encoder(
prosody_encoder_input.detach() if self.args.detach_prosody_enc_input else prosody_encoder_input,
y_lengths
)
pros_emb = pros_emb.transpose(1, 2)
@ -1337,10 +1340,10 @@ class Vits(BaseTTS):
pf_lengths = torch.tensor([pf.size(-1)]).to(pf.device)
z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=g)
if not self.args.use_prosody_encoder_z_p_input:
pros_emb, vae_outputs = self.prosody_encoder(z_pro, pf_lengths)
pros_emb, _ = self.prosody_encoder(z_pro, pf_lengths)
else:
z_p_inf = self.flow(z_pro, z_pro_y_mask, g=g)
pros_emb, vae_outputs = self.prosody_encoder(z_p_inf, pf_lengths)
pros_emb, _ = self.prosody_encoder(z_p_inf, pf_lengths)
pros_emb = pros_emb.transpose(1, 2)

View File

@ -50,6 +50,7 @@ config.model_args.use_text_enc_emo_classifier = True
config.model_args.use_prosody_encoder_z_p_input = True
config.model_args.prosody_encoder_type = "vae"
config.model_args.detach_prosody_enc_input = True
config.save_json(config_path)