Add ResNetBlock1D to UNet

This commit is contained in:
Subuday 2024-02-15 08:40:04 +00:00
parent 0f7a7edb9b
commit fd6c0afbbf
2 changed files with 39 additions and 5 deletions

View File

@ -36,26 +36,58 @@ class ConvBlock1D(nn.Module):
return output
class ResNetBlock1D(nn.Module):
    """Residual block of two ``ConvBlock1D`` layers with a time-embedding shift.

    The input is passed through a first conv block, a projection of the time
    embedding is added to every position, and the result goes through a second
    conv block. A 1x1 convolution of the masked input is added as the residual
    branch so the skip path matches ``out_channels``.

    Args:
        in_channels: Number of channels of the input ``x``.
        out_channels: Number of channels produced by the block.
        time_embed_channels: Size of the last dimension of the time
            embedding ``t`` fed to the internal ``nn.Linear``.
        num_groups: Forwarded to ``ConvBlock1D`` as ``num_groups``.
    """

    def __init__(self, in_channels, out_channels, time_embed_channels, num_groups=8):
        super().__init__()
        self.block_1 = ConvBlock1D(in_channels, out_channels, num_groups=num_groups)
        # Projects the time embedding to one shift value per output channel.
        self.mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(time_embed_channels, out_channels),
        )
        # block_2 consumes the output of block_1, which already has
        # ``out_channels`` channels, so its input width must be
        # ``out_channels`` (not ``in_channels``); otherwise the block fails
        # whenever in_channels != out_channels.
        self.block_2 = ConvBlock1D(out_channels, out_channels, num_groups=num_groups)
        # 1x1 convolution so the residual branch matches ``out_channels``.
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)

    def forward(self, x, mask, t):
        """Apply the residual block.

        Args:
            x: Input tensor for the conv blocks and the 1x1 residual conv.
                NOTE(review): presumably (batch, in_channels, length) - confirm
                against ConvBlock1D.
            mask: Tensor multiplied element-wise with ``x`` on the residual
                path and passed through to both conv blocks.
            t: Time embedding whose last dimension is ``time_embed_channels``.
                NOTE(review): presumably (batch, time_embed_channels) - confirm.

        Returns:
            ``block_2(block_1(x) + mlp(t)) + conv(x * mask)``.
        """
        h = self.block_1(x, mask)
        # A trailing axis is appended so the per-channel shift broadcasts over
        # the last axis of ``h``. Out-of-place add avoids mutating the tensor
        # returned by block_1.
        h = h + self.mlp(t).unsqueeze(-1)
        h = self.block_2(h, mask)
        return h + self.conv(x * mask)
class UNet(nn.Module):
def __init__(
self,
in_channels: int,
model_channels: int,
out_channels: int,
num_blocks: int,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.time_encoder = PositionalEncoding(in_channels)
time_embed_dim = model_channels * 4
time_embed_channels = model_channels * 4
self.time_embed = nn.Sequential(
nn.Linear(in_channels, time_embed_dim),
nn.Linear(in_channels, time_embed_channels),
nn.SiLU(),
nn.Linear(time_embed_dim, time_embed_dim),
nn.Linear(time_embed_channels, time_embed_channels),
)
self.input_blocks = nn.ModuleList([])
block_in_channels = in_channels
for _ in range(num_blocks):
block = nn.ModuleList([])
block.append(
ResNetBlock1D(
in_channels=block_in_channels,
out_channels=model_channels,
time_embed_channels=time_embed_channels
)
)
self.middle_blocks = nn.ModuleList([])
self.output_blocks = nn.ModuleList([])
@ -68,8 +100,9 @@ class UNet(nn.Module):
x_t = pack([x_t, mean], "b * t")[0]
for _ in self.input_blocks:
pass
for block in self.input_blocks:
res_net_block = block[0]
x_t = res_net_block(x_t, mask, t)
for _ in self.middle_blocks:
pass

View File

@ -13,6 +13,7 @@ class Decoder(nn.Module):
in_channels=80,
model_channels=160,
out_channels=80,
num_blocks=2
)
def forward(self, x_1, mean, mask):