From 6126e5e58827fc64ec2350cddf8cbdeeee6e9255 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 31 Mar 2022 19:14:41 -0300 Subject: [PATCH] Add emotion embedding in the encoder --- TTS/tts/layers/vits/networks.py | 10 +++++++++- TTS/tts/models/vits.py | 12 ++++++------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/TTS/tts/layers/vits/networks.py b/TTS/tts/layers/vits/networks.py index f97b584f..e669d589 100644 --- a/TTS/tts/layers/vits/networks.py +++ b/TTS/tts/layers/vits/networks.py @@ -38,6 +38,7 @@ class TextEncoder(nn.Module): kernel_size: int, dropout_p: float, language_emb_dim: int = None, + emotion_emb_dim: int = None, ): """Text Encoder for VITS model. @@ -62,6 +63,9 @@ class TextEncoder(nn.Module): if language_emb_dim: hidden_channels += language_emb_dim + if emotion_emb_dim: + hidden_channels += emotion_emb_dim + self.encoder = RelativePositionTransformer( in_channels=hidden_channels, out_channels=hidden_channels, @@ -77,7 +81,7 @@ class TextEncoder(nn.Module): self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - def forward(self, x, x_lengths, lang_emb=None): + def forward(self, x, x_lengths, lang_emb=None, emo_emb=None): """ Shapes: - x: :math:`[B, T]` @@ -90,6 +94,10 @@ class TextEncoder(nn.Module): if lang_emb is not None: x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1) + # concat the emotion emb in embedding chars + if emo_emb is not None: + x = torch.cat((x, emo_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1) + x = torch.transpose(x, 1, -1) # [b, h, t] x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) # [b, 1, t] diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index ef657300..326897e4 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -623,6 +623,7 @@ class Vits(BaseTTS): self.args.kernel_size_text_encoder, self.args.dropout_p_text_encoder, language_emb_dim=self.embedded_language_dim, + emotion_emb_dim=self.args.emotion_embedding_dim, ) self.posterior_encoder = PosteriorEncoder( @@ -646,7 +647,7 @@ class Vits(BaseTTS): if self.args.use_sdp: self.duration_predictor = StochasticDurationPredictor( - self.args.hidden_channels, + self.args.hidden_channels + self.args.emotion_embedding_dim, 192, 3, self.args.dropout_p_duration_predictor, @@ -656,7 +657,7 @@ class Vits(BaseTTS): ) else: self.duration_predictor = DurationPredictor( - self.args.hidden_channels, + self.args.hidden_channels + self.args.emotion_embedding_dim, 256, 3, self.args.dropout_p_duration_predictor, @@ -1049,7 +1050,7 @@ class Vits(BaseTTS): if self.args.use_language_embedding and lid is not None: lang_emb = self.emb_l(lid).unsqueeze(-1) - x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb) + x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg) # posterior encoder z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g) @@ -1179,7 +1180,7 @@ class Vits(BaseTTS): if self.args.use_language_embedding and lid is not None: lang_emb = self.emb_l(lid).unsqueeze(-1) - x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb) + x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg) if self.args.use_sdp: logw = self.duration_predictor( @@ -1816,9 +1817,8 @@ class Vits(BaseTTS): if config.model_args.encoder_model_path and speaker_manager is not None: speaker_manager.init_encoder(config.model_args.encoder_model_path, config.model_args.encoder_config_path) - elif config.model_args.encoder_model_path and emotion_manager is not None: + if config.model_args.encoder_model_path and emotion_manager is not None: emotion_manager.init_encoder(config.model_args.encoder_model_path, config.model_args.encoder_config_path) - return Vits(new_config, ap, tokenizer, speaker_manager, language_manager, emotion_manager=emotion_manager)