Add upsampling and downsampling to UNet

This commit is contained in:
Subuday 2024-02-15 13:24:30 +00:00
parent 8676ab30d9
commit 5fd7ea93ea
2 changed files with 71 additions and 7 deletions

View File

@ -44,7 +44,7 @@ class ResNetBlock1D(nn.Module):
nn.Mish(),
nn.Linear(time_embed_channels, out_channels)
)
self.block_2 = ConvBlock1D(in_channels, out_channels, num_groups=num_groups)
self.block_2 = ConvBlock1D(in_channels=out_channels, out_channels=out_channels, num_groups=num_groups)
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)
def forward(self, x, mask, t):
@ -55,6 +55,24 @@ class ResNetBlock1D(nn.Module):
return output
class Downsample1D(nn.Module):
def __init__(self, channels):
super().__init__()
self.conv = nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=3, stride=2, padding=1)
def forward(self, x):
return self.conv(x)
class Upsample1D(nn.Module):
def __init__(self, channels):
super().__init__()
self.conv = nn.ConvTranspose1d(in_channels=channels, out_channels=channels, kernel_size=4, stride=2, padding=1)
def forward(self, x):
return self.conv(x)
class UNet(nn.Module):
def __init__(
self,
@ -77,21 +95,49 @@ class UNet(nn.Module):
self.input_blocks = nn.ModuleList([])
block_in_channels = in_channels * 2
for _ in range(num_blocks):
block_out_channels = model_channels
for level in range(num_blocks):
block = nn.ModuleList([])
block.append(
ResNetBlock1D(
in_channels=block_in_channels,
out_channels=model_channels,
out_channels=block_out_channels,
time_embed_channels=time_embed_channels
)
)
if level != num_blocks - 1:
block.append(Downsample1D(block_out_channels))
else:
block.append(None)
block_in_channels = block_out_channels
self.input_blocks.append(block)
self.middle_blocks = nn.ModuleList([])
self.output_blocks = nn.ModuleList([])
block_in_channels = block_out_channels * 2
block_out_channels = model_channels
for level in range(num_blocks):
block = nn.ModuleList([])
block.append(
ResNetBlock1D(
in_channels=block_in_channels,
out_channels=block_out_channels,
time_embed_channels=time_embed_channels
)
)
if level != num_blocks - 1:
block.append(Upsample1D(block_out_channels))
else:
block.append(None)
block_in_channels = block_out_channels * 2
self.output_blocks.append(block)
self.conv_block = ConvBlock1D(model_channels, model_channels)
self.conv = nn.Conv1d(model_channels, self.out_channels, 1)
@ -102,15 +148,33 @@ class UNet(nn.Module):
x_t = pack([x_t, mean], "b * t")[0]
hidden_states = []
mask_states = [mask]
for block in self.input_blocks:
res_net_block = block[0]
res_net_block, downsample = block
x_t = res_net_block(x_t, mask, t)
hidden_states.append(x_t)
if downsample is not None:
x_t = downsample(x_t * mask)
mask = mask[:, :, ::2]
mask_states.append(mask)
for _ in self.middle_blocks:
pass
for _ in self.output_blocks:
pass
for block in self.output_blocks:
res_net_block, upsample = block
x_t = pack([x_t, hidden_states.pop()], "b * t")[0]
mask = mask_states.pop()
x_t = res_net_block(x_t, mask, t)
if upsample is not None:
x_t = upsample(x_t * mask)
output = self.conv_block(x_t)
output = self.conv(x_t)

View File

@ -11,7 +11,7 @@ class Decoder(nn.Module):
self.sigma_min = 1e-5
self.predictor = UNet(
in_channels=80,
model_channels=160,
model_channels=256,
out_channels=80,
num_blocks=2
)