mirror of https://github.com/coqui-ai/TTS.git
Add conv block to UNet
This commit is contained in:
parent
b5467b8051
commit
0f7a7edb9b
|
@ -18,6 +18,23 @@ class PositionalEncoding(torch.nn.Module):
|
|||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
|
||||
class ConvBlock1D(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, num_groups=8):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
|
||||
nn.GroupNorm(num_groups, out_channels),
|
||||
nn.Mish()
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
if mask is not None:
|
||||
x = x * mask
|
||||
output = self.block(x)
|
||||
if mask is not None:
|
||||
output = output * mask
|
||||
return output
|
||||
|
||||
|
||||
class UNet(nn.Module):
|
||||
def __init__(
|
||||
|
@ -42,6 +59,7 @@ class UNet(nn.Module):
|
|||
self.middle_blocks = nn.ModuleList([])
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
|
||||
self.conv_block = ConvBlock1D(model_channels, model_channels)
|
||||
self.conv = nn.Conv1d(model_channels, self.out_channels, 1)
|
||||
|
||||
def forward(self, x_t, mean, mask, t):
|
||||
|
@ -59,6 +77,7 @@ class UNet(nn.Module):
|
|||
for _ in self.output_blocks:
|
||||
pass
|
||||
|
||||
output = self.conv_block(x_t)
|
||||
output = self.conv(x_t)
|
||||
|
||||
return output * mask
|
Loading…
Reference in New Issue