diff --git a/TTS/tts/layers/feed_forward/decoder.py b/TTS/tts/layers/feed_forward/decoder.py index 34c586aa..70598f91 100644 --- a/TTS/tts/layers/feed_forward/decoder.py +++ b/TTS/tts/layers/feed_forward/decoder.py @@ -117,7 +117,7 @@ class FFTransformerDecoder(nn.Module): self.postnet = nn.Conv1d(in_channels, out_channels, 1) def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument - # TODO: handle multi-speaker + # TODO: maybe pass g to every block x_mask = 1 if x_mask is None else x_mask o = self.transformer_block(x) * x_mask o = self.postnet(o) * x_mask @@ -191,6 +191,9 @@ class Decoder(nn.Module): ): super().__init__() + if c_in_channels and c_in_channels != 0: + self.cond = nn.Conv1d(c_in_channels, in_hidden_channels, 1) + if decoder_type.lower() == "relative_position_transformer": self.decoder = RelativePositionTransformerDecoder( in_channels=in_hidden_channels, @@ -225,6 +228,9 @@ class Decoder(nn.Module): x_mask: [B, 1, T] g: [B, C_g, 1] """ - # TODO: implement multi-speaker - o = self.decoder(x, x_mask, g) + # multi-speaker conditioning + if hasattr(self, "cond") and self.cond is not None: + g = self.cond(g) + x = x + g + o = self.decoder(x=x, x_mask=x_mask, g=g) return o