From 8f21991a8462018b667bd46349a2e7f5821c257e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 4 Apr 2022 09:44:20 +0200 Subject: [PATCH] Add cond layer in decoder --- TTS/tts/layers/feed_forward/decoder.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/TTS/tts/layers/feed_forward/decoder.py b/TTS/tts/layers/feed_forward/decoder.py index 34c586aa..70598f91 100644 --- a/TTS/tts/layers/feed_forward/decoder.py +++ b/TTS/tts/layers/feed_forward/decoder.py @@ -117,7 +117,7 @@ class FFTransformerDecoder(nn.Module): self.postnet = nn.Conv1d(in_channels, out_channels, 1) def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument - # TODO: handle multi-speaker + # TODO: maybe pass g to every block x_mask = 1 if x_mask is None else x_mask o = self.transformer_block(x) * x_mask o = self.postnet(o) * x_mask @@ -191,6 +191,9 @@ class Decoder(nn.Module): ): super().__init__() + if c_in_channels and c_in_channels != 0: + self.cond = nn.Conv1d(c_in_channels, in_hidden_channels, 1) + if decoder_type.lower() == "relative_position_transformer": self.decoder = RelativePositionTransformerDecoder( in_channels=in_hidden_channels, @@ -225,6 +228,9 @@ class Decoder(nn.Module): x_mask: [B, 1, T] g: [B, C_g, 1] """ - # TODO: implement multi-speaker - o = self.decoder(x, x_mask, g) + # multi-speaker conditioning + if hasattr(self, "cond") and self.cond is not None: + g = self.cond(g) + x = x + g + o = self.decoder(x=x, x_mask=x_mask, g=g) return o