mirror of https://github.com/coqui-ai/TTS.git
Add speaker embedding on prosody encoder
This commit is contained in:
parent
251e1c289d
commit
92e7391a5d
|
@ -7,6 +7,8 @@ class VitsGST(GST):
|
|||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, inputs, input_lengths=None, speaker_embedding=None):
    """Run the parent style encoder, optionally conditioned on a speaker embedding.

    Args:
        inputs: Passed unchanged to the parent class ``forward``.
        input_lengths: Accepted for signature compatibility; not used here.
        speaker_embedding: Optional conditioning input. When given, it is
            squeezed along its last dimension before being handed to the
            parent. (Presumably arrives with a trailing singleton axis --
            confirm against the caller.)

    Returns:
        A ``(style_embed, None)`` tuple; the second slot is always ``None``.
    """
    spk_cond = None if speaker_embedding is None else speaker_embedding.squeeze(-1)
    return super().forward(inputs, speaker_embedding=spk_cond), None
|
||||
|
||||
|
@ -16,8 +18,10 @@ class VitsVAE(CapacitronVAE):
|
|||
super().__init__(*args, **kwargs)
|
||||
self.beta = None
|
||||
|
||||
def forward(self, inputs, input_lengths=None):
|
||||
VAE_embedding, posterior_distribution, prior_distribution, _ = super().forward([inputs, input_lengths])
|
||||
def forward(self, inputs, input_lengths=None, speaker_embedding=None):
    """Run the parent VAE encoder, optionally conditioned on a speaker embedding.

    Args:
        inputs: Reference input; bundled with ``input_lengths`` into a
            two-element list for the parent class ``forward``.
        input_lengths: Forwarded to the parent inside that list.
        speaker_embedding: Optional conditioning input. When given, it is
            squeezed along its last dimension before being handed to the
            parent. (Presumably arrives with a trailing singleton axis --
            confirm against the caller.)

    Returns:
        A tuple ``(embedding, [posterior_distribution, prior_distribution])``
        where the embedding has been moved to ``inputs.device``. The fourth
        value returned by the parent is discarded.
    """
    spk_cond = None if speaker_embedding is None else speaker_embedding.squeeze(-1)
    parent_outputs = super().forward([inputs, input_lengths], speaker_embedding=spk_cond)
    embedding, posterior, prior, _ = parent_outputs
    # Place the embedding on the same device as the input before returning it.
    return embedding.to(inputs.device), [posterior, prior]
|
||||
|
||||
|
||||
|
@ -25,6 +29,6 @@ class ResNetProsodyEncoder(ResNetSpeakerEncoder):
|
|||
# Pass-through constructor: forwards every positional and keyword argument
# unchanged to the parent class initializer and adds no state of its own.
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, inputs, input_lengths=None):
|
||||
def forward(self, inputs, input_lengths=None, speaker_embedding=None):
    """Compute a style embedding with the parent encoder.

    Args:
        inputs: Passed to the parent class ``forward`` with ``l2_norm=True``.
        input_lengths: Accepted for signature compatibility; not used here.
        speaker_embedding: Accepted for signature compatibility with the
            other prosody encoders in this file; not used here.

    Returns:
        A ``(style_embed, None)`` tuple, where ``style_embed`` is the parent
        output with a new axis inserted at position 1.
    """
    return super().forward(inputs, l2_norm=True).unsqueeze(1), None
|
|
@ -555,8 +555,11 @@ class VitsArgs(Coqpit):
|
|||
|
||||
# prosody encoder
|
||||
use_prosody_encoder: bool = False
|
||||
use_pros_enc_input_as_pros_emb: bool = False
|
||||
prosody_encoder_type: str = "gst"
|
||||
detach_prosody_enc_input: bool = False
|
||||
condition_pros_enc_on_speaker: bool = False
|
||||
|
||||
prosody_embedding_dim: int = 0
|
||||
prosody_encoder_num_heads: int = 1
|
||||
prosody_encoder_num_tokens: int = 5
|
||||
|
@ -715,27 +718,34 @@ class Vits(BaseTTS):
|
|||
)
|
||||
|
||||
if self.args.use_prosody_encoder:
|
||||
if self.args.prosody_encoder_type == "gst":
|
||||
self.prosody_encoder = VitsGST(
|
||||
num_mel=self.args.hidden_channels,
|
||||
num_heads=self.args.prosody_encoder_num_heads,
|
||||
num_style_tokens=self.args.prosody_encoder_num_tokens,
|
||||
gst_embedding_dim=self.args.prosody_embedding_dim,
|
||||
)
|
||||
elif self.args.prosody_encoder_type == "vae":
|
||||
self.prosody_encoder = VitsVAE(
|
||||
num_mel=self.args.hidden_channels,
|
||||
capacitron_VAE_embedding_dim=self.args.prosody_embedding_dim,
|
||||
)
|
||||
elif self.args.prosody_encoder_type == "resnet":
|
||||
self.prosody_encoder = ResNetProsodyEncoder(
|
||||
input_dim=self.args.hidden_channels,
|
||||
proj_dim=self.args.prosody_embedding_dim,
|
||||
if self.args.use_pros_enc_input_as_pros_emb:
|
||||
self.prosody_embedding_squeezer = nn.Linear(
|
||||
in_features=self.args.hidden_channels, out_features=self.args.prosody_embedding_dim
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!"
|
||||
)
|
||||
if self.args.prosody_encoder_type == "gst":
|
||||
self.prosody_encoder = VitsGST(
|
||||
num_mel=self.args.hidden_channels,
|
||||
num_heads=self.args.prosody_encoder_num_heads,
|
||||
num_style_tokens=self.args.prosody_encoder_num_tokens,
|
||||
gst_embedding_dim=self.args.prosody_embedding_dim,
|
||||
embedded_speaker_dim=self.cond_embedding_dim if self.args.condition_pros_enc_on_speaker else None,
|
||||
)
|
||||
elif self.args.prosody_encoder_type == "vae":
|
||||
self.prosody_encoder = VitsVAE(
|
||||
num_mel=self.args.hidden_channels,
|
||||
capacitron_VAE_embedding_dim=self.args.prosody_embedding_dim,
|
||||
speaker_embedding_dim=self.cond_embedding_dim if self.args.condition_pros_enc_on_speaker else None,
|
||||
)
|
||||
elif self.args.prosody_encoder_type == "resnet":
|
||||
self.prosody_encoder = ResNetProsodyEncoder(
|
||||
input_dim=self.args.hidden_channels,
|
||||
proj_dim=self.args.prosody_embedding_dim,
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!"
|
||||
)
|
||||
if self.args.use_prosody_enc_spk_reversal_classifier:
|
||||
self.speaker_reversal_classifier = ReversalClassifier(
|
||||
in_channels=self.args.prosody_embedding_dim,
|
||||
|
@ -1223,10 +1233,16 @@ class Vits(BaseTTS):
|
|||
l_pros_emotion = None
|
||||
if self.args.use_prosody_encoder:
|
||||
prosody_encoder_input = z_p if self.args.use_prosody_encoder_z_p_input else z
|
||||
pros_emb, vae_outputs = self.prosody_encoder(
|
||||
prosody_encoder_input.detach() if self.args.detach_prosody_enc_input else prosody_encoder_input,
|
||||
y_lengths,
|
||||
)
|
||||
if not self.args.use_pros_enc_input_as_pros_emb:
|
||||
pros_emb, vae_outputs = self.prosody_encoder(
|
||||
prosody_encoder_input.detach() if self.args.detach_prosody_enc_input else prosody_encoder_input,
|
||||
y_lengths,
|
||||
speaker_embedding=g if self.args.condition_pros_enc_on_speaker else None
|
||||
)
|
||||
else:
|
||||
pros_emb = prosody_encoder_input.mean(2).unsqueeze(1).detach()
|
||||
pros_emb = F.normalize(self.prosody_embedding_squeezer(pros_emb.squeeze(1))).unsqueeze(1)
|
||||
|
||||
pros_emb = pros_emb.transpose(1, 2)
|
||||
|
||||
if self.args.use_prosody_enc_spk_reversal_classifier:
|
||||
|
@ -1427,13 +1443,19 @@ class Vits(BaseTTS):
|
|||
# extract posterior encoder feature
|
||||
pf_lengths = torch.tensor([pf.size(-1)]).to(pf.device)
|
||||
z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=ssg)
|
||||
if not self.args.use_prosody_encoder_z_p_input:
|
||||
pros_emb, _ = self.prosody_encoder(z_pro, pf_lengths)
|
||||
if not self.args.use_pros_enc_input_as_pros_emb:
|
||||
if not self.args.use_prosody_encoder_z_p_input:
|
||||
pros_emb, _ = self.prosody_encoder(z_pro, pf_lengths, speaker_embedding=ssg if self.args.condition_pros_enc_on_speaker else None)
|
||||
else:
|
||||
z_p_inf = self.flow(z_pro, z_pro_y_mask, g=ssg)
|
||||
pros_emb, _ = self.prosody_encoder(z_p_inf, pf_lengths, speaker_embedding=ssg if self.args.condition_pros_enc_on_speaker else None)
|
||||
else:
|
||||
z_p_inf = self.flow(z_pro, z_pro_y_mask, g=ssg)
|
||||
pros_emb, _ = self.prosody_encoder(z_p_inf, pf_lengths)
|
||||
prosody_encoder_input = self.flow(z_pro, z_pro_y_mask, g=ssg) if self.args.use_prosody_encoder_z_p_input else z_pro
|
||||
pros_emb = prosody_encoder_input.mean(2).unsqueeze(1)
|
||||
pros_emb = F.normalize(self.prosody_embedding_squeezer(pros_emb.squeeze(1))).unsqueeze(1)
|
||||
|
||||
pros_emb = pros_emb.transpose(1, 2)
|
||||
|
||||
x, m_p, logs_p, x_mask = self.text_encoder(
|
||||
x,
|
||||
x_lengths,
|
||||
|
|
|
@ -49,11 +49,17 @@ config.model_args.use_prosody_enc_emo_classifier = False
|
|||
config.model_args.use_text_enc_emo_classifier = False
|
||||
config.model_args.use_prosody_encoder_z_p_input = True
|
||||
|
||||
config.model_args.prosody_encoder_type = "resnet"
|
||||
config.model_args.prosody_encoder_type = "gst"
|
||||
config.model_args.detach_prosody_enc_input = True
|
||||
|
||||
|
||||
config.model_args.use_latent_discriminator = True
|
||||
config.model_args.use_noise_scale_predictor = False
|
||||
config.model_args.condition_pros_enc_on_speaker = True
|
||||
|
||||
config.model_args.use_pros_enc_input_as_pros_emb = True
|
||||
config.model_args.use_prosody_embedding_squeezer = True
|
||||
config.model_args.prosody_embedding_squeezer_input_dim = 192
|
||||
|
||||
# enable end2end loss
|
||||
config.model_args.use_end2end_loss = False
|
||||
|
|
Loading…
Reference in New Issue