mirror of https://github.com/coqui-ai/TTS.git
Add conv block to UNet
This commit is contained in:
parent
b5467b8051
commit
0f7a7edb9b
|
@ -18,6 +18,23 @@ class PositionalEncoding(torch.nn.Module):
|
||||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||||
return emb
|
return emb
|
||||||
|
|
||||||
|
class ConvBlock1D(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, num_groups=8):
|
||||||
|
super().__init__()
|
||||||
|
self.block = nn.Sequential(
|
||||||
|
nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
|
||||||
|
nn.GroupNorm(num_groups, out_channels),
|
||||||
|
nn.Mish()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, mask=None):
|
||||||
|
if mask is not None:
|
||||||
|
x = x * mask
|
||||||
|
output = self.block(x)
|
||||||
|
if mask is not None:
|
||||||
|
output = output * mask
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
class UNet(nn.Module):
|
class UNet(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -42,6 +59,7 @@ class UNet(nn.Module):
|
||||||
self.middle_blocks = nn.ModuleList([])
|
self.middle_blocks = nn.ModuleList([])
|
||||||
self.output_blocks = nn.ModuleList([])
|
self.output_blocks = nn.ModuleList([])
|
||||||
|
|
||||||
|
self.conv_block = ConvBlock1D(model_channels, model_channels)
|
||||||
self.conv = nn.Conv1d(model_channels, self.out_channels, 1)
|
self.conv = nn.Conv1d(model_channels, self.out_channels, 1)
|
||||||
|
|
||||||
def forward(self, x_t, mean, mask, t):
|
def forward(self, x_t, mean, mask, t):
|
||||||
|
@ -59,6 +77,7 @@ class UNet(nn.Module):
|
||||||
for _ in self.output_blocks:
|
for _ in self.output_blocks:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
output = self.conv_block(x_t)
|
||||||
output = self.conv(x_t)
|
output = self.conv(x_t)
|
||||||
|
|
||||||
return output * mask
|
return output * mask
|
Loading…
Reference in New Issue