Add speaker embedding on prosody encoder

This commit is contained in:
Edresson Casanova 2022-06-16 19:06:48 +00:00
parent 251e1c289d
commit 92e7391a5d
3 changed files with 63 additions and 31 deletions

View File

@ -7,6 +7,8 @@ class VitsGST(GST):
super().__init__(*args, **kwargs)
def forward(self, inputs, input_lengths=None, speaker_embedding=None):
if speaker_embedding is not None:
speaker_embedding = speaker_embedding.squeeze(-1)
style_embed = super().forward(inputs, speaker_embedding=speaker_embedding)
return style_embed, None
@ -16,8 +18,10 @@ class VitsVAE(CapacitronVAE):
super().__init__(*args, **kwargs)
self.beta = None
def forward(self, inputs, input_lengths=None):
VAE_embedding, posterior_distribution, prior_distribution, _ = super().forward([inputs, input_lengths])
def forward(self, inputs, input_lengths=None, speaker_embedding=None):
if speaker_embedding is not None:
speaker_embedding = speaker_embedding.squeeze(-1)
VAE_embedding, posterior_distribution, prior_distribution, _ = super().forward([inputs, input_lengths], speaker_embedding=speaker_embedding)
return VAE_embedding.to(inputs.device), [posterior_distribution, prior_distribution]
@ -25,6 +29,6 @@ class ResNetProsodyEncoder(ResNetSpeakerEncoder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, inputs, input_lengths=None):
def forward(self, inputs, input_lengths=None, speaker_embedding=None):
style_embed = super().forward(inputs, l2_norm=True).unsqueeze(1)
return style_embed, None

View File

@ -555,8 +555,11 @@ class VitsArgs(Coqpit):
# prosody encoder
use_prosody_encoder: bool = False
use_pros_enc_input_as_pros_emb: bool = False
prosody_encoder_type: str = "gst"
detach_prosody_enc_input: bool = False
condition_pros_enc_on_speaker: bool = False
prosody_embedding_dim: int = 0
prosody_encoder_num_heads: int = 1
prosody_encoder_num_tokens: int = 5
@ -715,27 +718,34 @@ class Vits(BaseTTS):
)
if self.args.use_prosody_encoder:
if self.args.prosody_encoder_type == "gst":
self.prosody_encoder = VitsGST(
num_mel=self.args.hidden_channels,
num_heads=self.args.prosody_encoder_num_heads,
num_style_tokens=self.args.prosody_encoder_num_tokens,
gst_embedding_dim=self.args.prosody_embedding_dim,
)
elif self.args.prosody_encoder_type == "vae":
self.prosody_encoder = VitsVAE(
num_mel=self.args.hidden_channels,
capacitron_VAE_embedding_dim=self.args.prosody_embedding_dim,
)
elif self.args.prosody_encoder_type == "resnet":
self.prosody_encoder = ResNetProsodyEncoder(
input_dim=self.args.hidden_channels,
proj_dim=self.args.prosody_embedding_dim,
if self.args.use_pros_enc_input_as_pros_emb:
self.prosody_embedding_squeezer = nn.Linear(
in_features=self.args.hidden_channels, out_features=self.args.prosody_embedding_dim
)
else:
raise RuntimeError(
f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!"
)
if self.args.prosody_encoder_type == "gst":
self.prosody_encoder = VitsGST(
num_mel=self.args.hidden_channels,
num_heads=self.args.prosody_encoder_num_heads,
num_style_tokens=self.args.prosody_encoder_num_tokens,
gst_embedding_dim=self.args.prosody_embedding_dim,
embedded_speaker_dim=self.cond_embedding_dim if self.args.condition_pros_enc_on_speaker else None,
)
elif self.args.prosody_encoder_type == "vae":
self.prosody_encoder = VitsVAE(
num_mel=self.args.hidden_channels,
capacitron_VAE_embedding_dim=self.args.prosody_embedding_dim,
speaker_embedding_dim=self.cond_embedding_dim if self.args.condition_pros_enc_on_speaker else None,
)
elif self.args.prosody_encoder_type == "resnet":
self.prosody_encoder = ResNetProsodyEncoder(
input_dim=self.args.hidden_channels,
proj_dim=self.args.prosody_embedding_dim,
)
else:
raise RuntimeError(
f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!"
)
if self.args.use_prosody_enc_spk_reversal_classifier:
self.speaker_reversal_classifier = ReversalClassifier(
in_channels=self.args.prosody_embedding_dim,
@ -1223,10 +1233,16 @@ class Vits(BaseTTS):
l_pros_emotion = None
if self.args.use_prosody_encoder:
prosody_encoder_input = z_p if self.args.use_prosody_encoder_z_p_input else z
pros_emb, vae_outputs = self.prosody_encoder(
prosody_encoder_input.detach() if self.args.detach_prosody_enc_input else prosody_encoder_input,
y_lengths,
)
if not self.args.use_pros_enc_input_as_pros_emb:
pros_emb, vae_outputs = self.prosody_encoder(
prosody_encoder_input.detach() if self.args.detach_prosody_enc_input else prosody_encoder_input,
y_lengths,
speaker_embedding=g if self.args.condition_pros_enc_on_speaker else None
)
else:
pros_emb = prosody_encoder_input.mean(2).unsqueeze(1).detach()
pros_emb = F.normalize(self.prosody_embedding_squeezer(pros_emb.squeeze(1))).unsqueeze(1)
pros_emb = pros_emb.transpose(1, 2)
if self.args.use_prosody_enc_spk_reversal_classifier:
@ -1427,13 +1443,19 @@ class Vits(BaseTTS):
# extract posterior encoder feature
pf_lengths = torch.tensor([pf.size(-1)]).to(pf.device)
z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=ssg)
if not self.args.use_prosody_encoder_z_p_input:
pros_emb, _ = self.prosody_encoder(z_pro, pf_lengths)
if not self.args.use_pros_enc_input_as_pros_emb:
if not self.args.use_prosody_encoder_z_p_input:
pros_emb, _ = self.prosody_encoder(z_pro, pf_lengths, speaker_embedding=ssg if self.args.condition_pros_enc_on_speaker else None)
else:
z_p_inf = self.flow(z_pro, z_pro_y_mask, g=ssg)
pros_emb, _ = self.prosody_encoder(z_p_inf, pf_lengths, speaker_embedding=ssg if self.args.condition_pros_enc_on_speaker else None)
else:
z_p_inf = self.flow(z_pro, z_pro_y_mask, g=ssg)
pros_emb, _ = self.prosody_encoder(z_p_inf, pf_lengths)
prosody_encoder_input = self.flow(z_pro, z_pro_y_mask, g=ssg) if self.args.use_prosody_encoder_z_p_input else z_pro
pros_emb = prosody_encoder_input.mean(2).unsqueeze(1)
pros_emb = F.normalize(self.prosody_embedding_squeezer(pros_emb.squeeze(1))).unsqueeze(1)
pros_emb = pros_emb.transpose(1, 2)
x, m_p, logs_p, x_mask = self.text_encoder(
x,
x_lengths,

View File

@ -49,11 +49,17 @@ config.model_args.use_prosody_enc_emo_classifier = False
config.model_args.use_text_enc_emo_classifier = False
config.model_args.use_prosody_encoder_z_p_input = True
config.model_args.prosody_encoder_type = "resnet"
config.model_args.prosody_encoder_type = "gst"
config.model_args.detach_prosody_enc_input = True
config.model_args.use_latent_discriminator = True
config.model_args.use_noise_scale_predictor = False
config.model_args.condition_pros_enc_on_speaker = True
config.model_args.use_pros_enc_input_as_pros_emb = True
config.model_args.use_prosody_embedding_squeezer = True
config.model_args.prosody_embedding_squeezer_input_dim = 192
# enable end2end loss
config.model_args.use_end2end_loss = False