Add conv block to UNet

This commit is contained in:
Subuday 2024-02-14 21:21:07 +00:00
parent b5467b8051
commit 0f7a7edb9b
1 changed files with 19 additions and 0 deletions

View File

@ -18,6 +18,23 @@ class PositionalEncoding(torch.nn.Module):
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class ConvBlock1D(nn.Module):
def __init__(self, in_channels, out_channels, num_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
nn.GroupNorm(num_groups, out_channels),
nn.Mish()
)
def forward(self, x, mask=None):
if mask is not None:
x = x * mask
output = self.block(x)
if mask is not None:
output = output * mask
return output
class UNet(nn.Module):
def __init__(
@ -42,6 +59,7 @@ class UNet(nn.Module):
self.middle_blocks = nn.ModuleList([])
self.output_blocks = nn.ModuleList([])
self.conv_block = ConvBlock1D(model_channels, model_channels)
self.conv = nn.Conv1d(model_channels, self.out_channels, 1)
def forward(self, x_t, mean, mask, t):
@ -59,6 +77,7 @@ class UNet(nn.Module):
for _ in self.output_blocks:
pass
output = self.conv_block(x_t)
output = self.conv(x_t)
return output * mask