mirror of https://github.com/coqui-ai/TTS.git
Add upsampling and downsampling to UNet
This commit is contained in:
parent
8676ab30d9
commit
5fd7ea93ea
|
@ -44,7 +44,7 @@ class ResNetBlock1D(nn.Module):
|
|||
nn.Mish(),
|
||||
nn.Linear(time_embed_channels, out_channels)
|
||||
)
|
||||
self.block_2 = ConvBlock1D(in_channels, out_channels, num_groups=num_groups)
|
||||
self.block_2 = ConvBlock1D(in_channels=out_channels, out_channels=out_channels, num_groups=num_groups)
|
||||
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)
|
||||
|
||||
def forward(self, x, mask, t):
|
||||
|
@ -55,6 +55,24 @@ class ResNetBlock1D(nn.Module):
|
|||
return output
|
||||
|
||||
|
||||
class Downsample1D(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=3, stride=2, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Upsample1D(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.conv = nn.ConvTranspose1d(in_channels=channels, out_channels=channels, kernel_size=4, stride=2, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class UNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -77,21 +95,49 @@ class UNet(nn.Module):
|
|||
|
||||
self.input_blocks = nn.ModuleList([])
|
||||
block_in_channels = in_channels * 2
|
||||
for _ in range(num_blocks):
|
||||
block_out_channels = model_channels
|
||||
for level in range(num_blocks):
|
||||
block = nn.ModuleList([])
|
||||
|
||||
block.append(
|
||||
ResNetBlock1D(
|
||||
in_channels=block_in_channels,
|
||||
out_channels=model_channels,
|
||||
out_channels=block_out_channels,
|
||||
time_embed_channels=time_embed_channels
|
||||
)
|
||||
)
|
||||
|
||||
if level != num_blocks - 1:
|
||||
block.append(Downsample1D(block_out_channels))
|
||||
else:
|
||||
block.append(None)
|
||||
|
||||
block_in_channels = block_out_channels
|
||||
self.input_blocks.append(block)
|
||||
|
||||
self.middle_blocks = nn.ModuleList([])
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
block_in_channels = block_out_channels * 2
|
||||
block_out_channels = model_channels
|
||||
for level in range(num_blocks):
|
||||
block = nn.ModuleList([])
|
||||
|
||||
block.append(
|
||||
ResNetBlock1D(
|
||||
in_channels=block_in_channels,
|
||||
out_channels=block_out_channels,
|
||||
time_embed_channels=time_embed_channels
|
||||
)
|
||||
)
|
||||
|
||||
if level != num_blocks - 1:
|
||||
block.append(Upsample1D(block_out_channels))
|
||||
else:
|
||||
block.append(None)
|
||||
|
||||
block_in_channels = block_out_channels * 2
|
||||
self.output_blocks.append(block)
|
||||
|
||||
self.conv_block = ConvBlock1D(model_channels, model_channels)
|
||||
self.conv = nn.Conv1d(model_channels, self.out_channels, 1)
|
||||
|
@ -102,15 +148,33 @@ class UNet(nn.Module):
|
|||
|
||||
x_t = pack([x_t, mean], "b * t")[0]
|
||||
|
||||
hidden_states = []
|
||||
mask_states = [mask]
|
||||
|
||||
for block in self.input_blocks:
|
||||
res_net_block = block[0]
|
||||
res_net_block, downsample = block
|
||||
|
||||
x_t = res_net_block(x_t, mask, t)
|
||||
hidden_states.append(x_t)
|
||||
|
||||
if downsample is not None:
|
||||
x_t = downsample(x_t * mask)
|
||||
mask = mask[:, :, ::2]
|
||||
mask_states.append(mask)
|
||||
|
||||
for _ in self.middle_blocks:
|
||||
pass
|
||||
|
||||
for _ in self.output_blocks:
|
||||
pass
|
||||
for block in self.output_blocks:
|
||||
res_net_block, upsample = block
|
||||
|
||||
x_t = pack([x_t, hidden_states.pop()], "b * t")[0]
|
||||
mask = mask_states.pop()
|
||||
x_t = res_net_block(x_t, mask, t)
|
||||
|
||||
if upsample is not None:
|
||||
x_t = upsample(x_t * mask)
|
||||
|
||||
|
||||
output = self.conv_block(x_t)
|
||||
output = self.conv(x_t)
|
||||
|
|
|
@ -11,7 +11,7 @@ class Decoder(nn.Module):
|
|||
self.sigma_min = 1e-5
|
||||
self.predictor = UNet(
|
||||
in_channels=80,
|
||||
model_channels=160,
|
||||
model_channels=256,
|
||||
out_channels=80,
|
||||
num_blocks=2
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue