mirror of https://github.com/coqui-ai/TTS.git
Add emotion embedding in the encoder
parent 1fdef1c4c9
commit 6126e5e588

@@ -38,6 +38,7 @@ class TextEncoder(nn.Module):
         kernel_size: int,
         dropout_p: float,
         language_emb_dim: int = None,
+        emotion_emb_dim: int = None,
     ):
         """Text Encoder for VITS model.

@@ -62,6 +63,9 @@ class TextEncoder(nn.Module):
         if language_emb_dim:
             hidden_channels += language_emb_dim

+        if emotion_emb_dim:
+            hidden_channels += emotion_emb_dim
+
         self.encoder = RelativePositionTransformer(
             in_channels=hidden_channels,
             out_channels=hidden_channels,

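Note: hidden_channels is widened here because forward() concatenates the language and emotion vectors onto every character embedding, so the downstream transformer has to accept the extra channels. A minimal stand-in sketch (a plain Conv1d in place of RelativePositionTransformer, with illustrative sizes that are assumptions, not values from this diff):

import torch
from torch import nn

# stand-in for the encoder: whatever consumes x must accept the widened width
emb_dim, lang_dim, emo_dim = 192, 4, 4
hidden = emb_dim + lang_dim + emo_dim

encoder = nn.Conv1d(hidden, hidden, kernel_size=1)
x = torch.randn(2, hidden, 13)              # [B, emb + lang + emo, T] after the concats
assert encoder(x).shape == (2, hidden, 13)
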
@@ -77,7 +81,7 @@ class TextEncoder(nn.Module):

         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

-    def forward(self, x, x_lengths, lang_emb=None):
+    def forward(self, x, x_lengths, lang_emb=None, emo_emb=None):
         """
         Shapes:
             - x: :math:`[B, T]`

@@ -90,6 +94,10 @@ class TextEncoder(nn.Module):
         if lang_emb is not None:
             x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)

+        # concat the emotion embedding to the character embeddings
+        if emo_emb is not None:
+            x = torch.cat((x, emo_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)
+
         x = torch.transpose(x, 1, -1)  # [b, h, t]
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)  # [b, 1, t]

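The concat above broadcasts one utterance-level emotion vector across all time steps, exactly like the language embedding path. A self-contained sketch of the shape arithmetic (sizes are illustrative assumptions):

import torch

B, T, H, C = 2, 13, 192, 4                       # batch, text length, hidden, emotion dim
x = torch.randn(B, T, H)                         # character embeddings, [B, T, H]
emo_emb = torch.randn(B, C, 1)                   # one emotion vector per utterance, [B, C, 1]

emo = emo_emb.transpose(2, 1).expand(B, T, -1)   # [B, 1, C] -> [B, T, C]
x = torch.cat((x, emo), dim=-1)                  # [B, T, H + C]
assert x.shape == (B, T, H + C)
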
@@ -623,6 +623,7 @@ class Vits(BaseTTS):
             self.args.kernel_size_text_encoder,
             self.args.dropout_p_text_encoder,
             language_emb_dim=self.embedded_language_dim,
+            emotion_emb_dim=self.args.emotion_embedding_dim,
         )

         self.posterior_encoder = PosteriorEncoder(

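The encoder is wired from self.args.emotion_embedding_dim, so VitsArgs presumably carries such a field. A hypothetical sketch of that wiring; the default of 0 is an assumption not shown in the diff, chosen because it would keep both the `if emotion_emb_dim:` guard and the `hidden_channels + emotion_embedding_dim` sums below valid when emotions are disabled:

from dataclasses import dataclass

@dataclass
class VitsArgs:                       # hypothetical excerpt, not from the diff
    hidden_channels: int = 192
    emotion_embedding_dim: int = 0    # 0 keeps the old behavior (assumption)
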
@@ -646,7 +647,7 @@ class Vits(BaseTTS):

         if self.args.use_sdp:
             self.duration_predictor = StochasticDurationPredictor(
-                self.args.hidden_channels,
+                self.args.hidden_channels + self.args.emotion_embedding_dim,
                 192,
                 3,
                 self.args.dropout_p_duration_predictor,

@@ -656,7 +657,7 @@ class Vits(BaseTTS):
             )
         else:
             self.duration_predictor = DurationPredictor(
-                self.args.hidden_channels,
+                self.args.hidden_channels + self.args.emotion_embedding_dim,
                 256,
                 3,
                 self.args.dropout_p_duration_predictor,

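Both duration-predictor branches grow their input width by emotion_embedding_dim because they consume the text-encoder output, which now carries the extra emotion channels. A minimal sketch with a Conv1d standing in for the predictor's first layer (sizes are illustrative assumptions):

import torch
from torch import nn

hidden_channels, emotion_embedding_dim = 192, 4
in_channels = hidden_channels + emotion_embedding_dim

pre = nn.Conv1d(in_channels, 256, kernel_size=3, padding=1)   # stand-in first layer
h = torch.randn(2, in_channels, 13)                           # encoder output, [B, H + C, T]
assert pre(h).shape == (2, 256, 13)
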
@@ -1049,7 +1050,7 @@ class Vits(BaseTTS):
         if self.args.use_language_embedding and lid is not None:
             lang_emb = self.emb_l(lid).unsqueeze(-1)

-        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
+        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg)

         # posterior encoder
         z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)

@@ -1179,7 +1180,7 @@ class Vits(BaseTTS):
         if self.args.use_language_embedding and lid is not None:
             lang_emb = self.emb_l(lid).unsqueeze(-1)

-        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
+        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg)

         if self.args.use_sdp:
             logw = self.duration_predictor(

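Both call sites now pass emo_emb=eg, but the diff never shows where eg is defined; it presumably mirrors the lang_emb = self.emb_l(lid).unsqueeze(-1) pattern above. A hypothetical, self-contained sketch (emb_emotion and emotion_ids are illustrative assumptions):

import torch
from torch import nn

num_emotions, emotion_embedding_dim = 5, 4
emb_emotion = nn.Embedding(num_emotions, emotion_embedding_dim)   # hypothetical lookup

emotion_ids = torch.tensor([0, 3])             # one id per utterance, [B]
eg = emb_emotion(emotion_ids).unsqueeze(-1)    # [B, C] -> [B, C, 1], ready for emo_emb=eg
assert eg.shape == (2, emotion_embedding_dim, 1)
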
@@ -1816,9 +1817,8 @@ class Vits(BaseTTS):

         if config.model_args.encoder_model_path and speaker_manager is not None:
             speaker_manager.init_encoder(config.model_args.encoder_model_path, config.model_args.encoder_config_path)
-        elif config.model_args.encoder_model_path and emotion_manager is not None:
+        if config.model_args.encoder_model_path and emotion_manager is not None:
             emotion_manager.init_encoder(config.model_args.encoder_model_path, config.model_args.encoder_config_path)

         return Vits(new_config, ap, tokenizer, speaker_manager, language_manager, emotion_manager=emotion_manager)

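The elif -> if change matters when a checkpoint provides both managers: with elif, the emotion encoder was never initialized once the speaker branch ran; two independent ifs initialize both. A miniature demonstration of that control-flow difference (names are illustrative):

def init_encoders(path, speaker_manager, emotion_manager):
    inits = []
    if path and speaker_manager is not None:
        inits.append("speaker")
    if path and emotion_manager is not None:   # was `elif`: skipped when speaker ran
        inits.append("emotion")
    return inits

assert init_encoders("enc.pth", object(), object()) == ["speaker", "emotion"]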