Add an option to detach the prosody encoder input

This commit is contained in:
Edresson Casanova 2022-05-27 13:19:06 -03:00
parent a2aecea8f3
commit 2568b722dd
2 changed files with 10 additions and 6 deletions

View File

@ -547,6 +547,7 @@ class VitsArgs(Coqpit):
  # prosody encoder
  use_prosody_encoder: bool = False
  prosody_encoder_type: str = "gst"
+ detach_prosody_enc_input: bool = False
  prosody_embedding_dim: int = 0
  prosody_encoder_num_heads: int = 1
  prosody_encoder_num_tokens: int = 5
@ -554,6 +555,7 @@ class VitsArgs(Coqpit):
  use_prosody_enc_spk_reversal_classifier: bool = False
  use_prosody_enc_emo_classifier: bool = False
  use_prosody_conditional_flow_module: bool = False
  prosody_conditional_flow_module_on_decoder: bool = False
@ -1153,10 +1155,11 @@ class Vits(BaseTTS):
  l_pros_speaker = None
  l_pros_emotion = None
  if self.args.use_prosody_encoder:
-     if not self.args.use_prosody_encoder_z_p_input:
-         pros_emb, vae_outputs = self.prosody_encoder(z, y_lengths)
-     else:
-         pros_emb, vae_outputs = self.prosody_encoder(z_p, y_lengths)
+     prosody_encoder_input = z_p if self.args.use_prosody_encoder_z_p_input else z
+     pros_emb, vae_outputs = self.prosody_encoder(
+         prosody_encoder_input.detach() if self.args.detach_prosody_enc_input else prosody_encoder_input,
+         y_lengths
+     )
  pros_emb = pros_emb.transpose(1, 2)
@ -1337,10 +1340,10 @@ class Vits(BaseTTS):
  pf_lengths = torch.tensor([pf.size(-1)]).to(pf.device)
  z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=g)
  if not self.args.use_prosody_encoder_z_p_input:
-     pros_emb, vae_outputs = self.prosody_encoder(z_pro, pf_lengths)
+     pros_emb, _ = self.prosody_encoder(z_pro, pf_lengths)
  else:
      z_p_inf = self.flow(z_pro, z_pro_y_mask, g=g)
-     pros_emb, vae_outputs = self.prosody_encoder(z_p_inf, pf_lengths)
+     pros_emb, _ = self.prosody_encoder(z_p_inf, pf_lengths)
  pros_emb = pros_emb.transpose(1, 2)

View File

@ -50,6 +50,7 @@ config.model_args.use_text_enc_emo_classifier = True
  config.model_args.use_prosody_encoder_z_p_input = True
  config.model_args.prosody_encoder_type = "vae"
+ config.model_args.detach_prosody_enc_input = True
  config.save_json(config_path)