Add cond layer in decoder

This commit is contained in:
Eren Gölge 2022-04-04 09:44:20 +02:00
parent 8408b983b2
commit 8f21991a84
1 changed file with 9 additions and 3 deletions

View File

@@ -117,7 +117,7 @@ class FFTransformerDecoder(nn.Module):
self.postnet = nn.Conv1d(in_channels, out_channels, 1) self.postnet = nn.Conv1d(in_channels, out_channels, 1)
def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument
# TODO: handle multi-speaker # TODO: maybe pass g to every block
x_mask = 1 if x_mask is None else x_mask x_mask = 1 if x_mask is None else x_mask
o = self.transformer_block(x) * x_mask o = self.transformer_block(x) * x_mask
o = self.postnet(o) * x_mask o = self.postnet(o) * x_mask
@@ -191,6 +191,9 @@ class Decoder(nn.Module):
): ):
super().__init__() super().__init__()
if c_in_channels and c_in_channels != 0:
self.cond = nn.Conv1d(c_in_channels, in_hidden_channels, 1)
if decoder_type.lower() == "relative_position_transformer": if decoder_type.lower() == "relative_position_transformer":
self.decoder = RelativePositionTransformerDecoder( self.decoder = RelativePositionTransformerDecoder(
in_channels=in_hidden_channels, in_channels=in_hidden_channels,
@@ -225,6 +228,9 @@ class Decoder(nn.Module):
x_mask: [B, 1, T] x_mask: [B, 1, T]
g: [B, C_g, 1] g: [B, C_g, 1]
""" """
# TODO: implement multi-speaker # multi-speaker conditioning
o = self.decoder(x, x_mask, g) if hasattr(self, "cond") and self.cond is not None:
g = self.cond(g)
x = x + g
o = self.decoder(x=x, x_mask=x_mask, g=g)
return o return o