mirror of https://github.com/coqui-ai/TTS.git
Add ResNetBlock1D to UNet
This commit is contained in:
parent
0f7a7edb9b
commit
fd6c0afbbf
|
@ -36,26 +36,58 @@ class ConvBlock1D(nn.Module):
|
|||
return output
|
||||
|
||||
|
||||
class ResNetBlock1D(nn.Module):
    """Residual 1D block conditioned on a time embedding.

    The input goes through two ``ConvBlock1D`` stages; between them a
    projection of the time embedding is added to the features. A 1x1
    convolution of the masked input is added to the result as the residual
    (skip) path.

    Args:
        in_channels: Number of channels of the input ``x``.
        out_channels: Number of channels produced by the block.
        time_embed_channels: Size of the last dimension of the time
            embedding ``t`` consumed by the projection MLP.
        num_groups: Forwarded to ``ConvBlock1D`` as ``num_groups``.
    """

    def __init__(self, in_channels, out_channels, time_embed_channels, num_groups=8):
        super().__init__()
        self.block_1 = ConvBlock1D(in_channels, out_channels, num_groups=num_groups)
        # Projects the time embedding to one value per output channel.
        self.mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(time_embed_channels, out_channels),
        )
        # block_2 consumes the output of block_1 (plus the time projection),
        # which has `out_channels` channels, so its input width must be
        # `out_channels`, not `in_channels`.
        self.block_2 = ConvBlock1D(out_channels, out_channels, num_groups=num_groups)
        # 1x1 convolution that maps the input to `out_channels` so it can be
        # summed with the block output on the residual path.
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)

    def forward(self, x, mask, t):
        """Apply the block.

        Args:
            x: Input features; fed to ``nn.Conv1d``, so channels are on
                dim 1 and the convolved axis is last.
            mask: Mask multiplied element-wise with ``x`` and passed to both
                ``ConvBlock1D`` stages (presumably broadcastable to ``x`` --
                confirm against ``ConvBlock1D``).
            t: Time embedding whose last dimension is
                ``time_embed_channels``.

        Returns:
            Tensor with ``out_channels`` channels: the second conv block's
            output plus the 1x1-convolved masked input.
        """
        h = self.block_1(x, mask)
        # Trailing singleton axis lets the per-channel projection broadcast
        # over the last axis of `h`. Out-of-place add so the tensor returned
        # by block_1 is not mutated.
        h = h + self.mlp(t).unsqueeze(-1)
        h = self.block_2(h, mask)
        return h + self.conv(x * mask)
|
||||
|
||||
|
||||
class UNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
model_channels: int,
|
||||
out_channels: int,
|
||||
num_blocks: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.time_encoder = PositionalEncoding(in_channels)
|
||||
time_embed_dim = model_channels * 4
|
||||
time_embed_channels = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
nn.Linear(in_channels, time_embed_dim),
|
||||
nn.Linear(in_channels, time_embed_channels),
|
||||
nn.SiLU(),
|
||||
nn.Linear(time_embed_dim, time_embed_dim),
|
||||
nn.Linear(time_embed_channels, time_embed_channels),
|
||||
)
|
||||
|
||||
self.input_blocks = nn.ModuleList([])
|
||||
block_in_channels = in_channels
|
||||
for _ in range(num_blocks):
|
||||
block = nn.ModuleList([])
|
||||
|
||||
block.append(
|
||||
ResNetBlock1D(
|
||||
in_channels=block_in_channels,
|
||||
out_channels=model_channels,
|
||||
time_embed_channels=time_embed_channels
|
||||
)
|
||||
)
|
||||
|
||||
self.middle_blocks = nn.ModuleList([])
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
|
||||
|
@ -68,8 +100,9 @@ class UNet(nn.Module):
|
|||
|
||||
x_t = pack([x_t, mean], "b * t")[0]
|
||||
|
||||
for _ in self.input_blocks:
|
||||
pass
|
||||
for block in self.input_blocks:
|
||||
res_net_block = block[0]
|
||||
x_t = res_net_block(x_t, mask, t)
|
||||
|
||||
for _ in self.middle_blocks:
|
||||
pass
|
||||
|
|
|
@ -13,6 +13,7 @@ class Decoder(nn.Module):
|
|||
in_channels=80,
|
||||
model_channels=160,
|
||||
out_channels=80,
|
||||
num_blocks=2
|
||||
)
|
||||
|
||||
def forward(self, x_1, mean, mask):
|
||||
|
|
Loading…
Reference in New Issue