mirror of https://github.com/coqui-ai/TTS.git
Add cond layer in decoder
This commit is contained in:
parent
8408b983b2
commit
8f21991a84
|
@@ -117,7 +117,7 @@ class FFTransformerDecoder(nn.Module):
|
|||
self.postnet = nn.Conv1d(in_channels, out_channels, 1)
|
||||
|
||||
def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument
|
||||
# TODO: handle multi-speaker
|
||||
# TODO: maybe pass g to every block
|
||||
x_mask = 1 if x_mask is None else x_mask
|
||||
o = self.transformer_block(x) * x_mask
|
||||
o = self.postnet(o) * x_mask
|
||||
|
@@ -191,6 +191,9 @@ class Decoder(nn.Module):
|
|||
):
|
||||
super().__init__()
|
||||
|
||||
if c_in_channels and c_in_channels != 0:
|
||||
self.cond = nn.Conv1d(c_in_channels, in_hidden_channels, 1)
|
||||
|
||||
if decoder_type.lower() == "relative_position_transformer":
|
||||
self.decoder = RelativePositionTransformerDecoder(
|
||||
in_channels=in_hidden_channels,
|
||||
|
@@ -225,6 +228,9 @@ class Decoder(nn.Module):
|
|||
x_mask: [B, 1, T]
|
||||
g: [B, C_g, 1]
|
||||
"""
|
||||
# TODO: implement multi-speaker
|
||||
o = self.decoder(x, x_mask, g)
|
||||
# multi-speaker conditioning
|
||||
if hasattr(self, "cond") and self.cond is not None:
|
||||
g = self.cond(g)
|
||||
x = x + g
|
||||
o = self.decoder(x=x, x_mask=x_mask, g=g)
|
||||
return o
|
||||
|
|
Loading…
Reference in New Issue