mirror of https://github.com/coqui-ai/TTS.git
Add cond layer in decoder
This commit is contained in:
parent
8408b983b2
commit
8f21991a84
|
@ -117,7 +117,7 @@ class FFTransformerDecoder(nn.Module):
|
||||||
self.postnet = nn.Conv1d(in_channels, out_channels, 1)
|
self.postnet = nn.Conv1d(in_channels, out_channels, 1)
|
||||||
|
|
||||||
def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument
|
def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument
|
||||||
# TODO: handle multi-speaker
|
# TODO: maybe pass g to every block
|
||||||
x_mask = 1 if x_mask is None else x_mask
|
x_mask = 1 if x_mask is None else x_mask
|
||||||
o = self.transformer_block(x) * x_mask
|
o = self.transformer_block(x) * x_mask
|
||||||
o = self.postnet(o) * x_mask
|
o = self.postnet(o) * x_mask
|
||||||
|
@ -191,6 +191,9 @@ class Decoder(nn.Module):
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
if c_in_channels and c_in_channels != 0:
|
||||||
|
self.cond = nn.Conv1d(c_in_channels, in_hidden_channels, 1)
|
||||||
|
|
||||||
if decoder_type.lower() == "relative_position_transformer":
|
if decoder_type.lower() == "relative_position_transformer":
|
||||||
self.decoder = RelativePositionTransformerDecoder(
|
self.decoder = RelativePositionTransformerDecoder(
|
||||||
in_channels=in_hidden_channels,
|
in_channels=in_hidden_channels,
|
||||||
|
@ -225,6 +228,9 @@ class Decoder(nn.Module):
|
||||||
x_mask: [B, 1, T]
|
x_mask: [B, 1, T]
|
||||||
g: [B, C_g, 1]
|
g: [B, C_g, 1]
|
||||||
"""
|
"""
|
||||||
# TODO: implement multi-speaker
|
# multi-speaker conditioning
|
||||||
o = self.decoder(x, x_mask, g)
|
if hasattr(self, "cond") and self.cond is not None:
|
||||||
|
g = self.cond(g)
|
||||||
|
x = x + g
|
||||||
|
o = self.decoder(x=x, x_mask=x_mask, g=g)
|
||||||
return o
|
return o
|
||||||
|
|
Loading…
Reference in New Issue