mirror of https://github.com/coqui-ai/TTS.git
Add upsampling and downsampling to UNet
This commit is contained in:
parent
8676ab30d9
commit
5fd7ea93ea
|
@ -44,7 +44,7 @@ class ResNetBlock1D(nn.Module):
|
||||||
nn.Mish(),
|
nn.Mish(),
|
||||||
nn.Linear(time_embed_channels, out_channels)
|
nn.Linear(time_embed_channels, out_channels)
|
||||||
)
|
)
|
||||||
self.block_2 = ConvBlock1D(in_channels, out_channels, num_groups=num_groups)
|
self.block_2 = ConvBlock1D(in_channels=out_channels, out_channels=out_channels, num_groups=num_groups)
|
||||||
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)
|
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)
|
||||||
|
|
||||||
def forward(self, x, mask, t):
|
def forward(self, x, mask, t):
|
||||||
|
@ -55,6 +55,24 @@ class ResNetBlock1D(nn.Module):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class Downsample1D(nn.Module):
    """Halve the temporal resolution of a 1D feature map.

    Wraps a single strided ``nn.Conv1d`` (kernel size 3, stride 2,
    padding 1) that keeps the number of channels unchanged. For an input
    of length ``T`` along the last axis the output length is
    ``ceil(T / 2)``, which matches the length of a ``[..., ::2]`` slice.

    Args:
        channels: Number of input channels; the output has the same number.
    """

    def __init__(self, channels):
        super().__init__()
        # kernel 3 / stride 2 / padding 1 -> L_out = floor((T - 1) / 2) + 1.
        self.conv = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        """Apply the strided convolution.

        Args:
            x: Tensor in the layout ``nn.Conv1d`` expects, with ``channels``
                on the channel axis and time on the last axis.

        Returns:
            Tensor with the same channel count and a halved (rounded up)
            last axis.
        """
        return self.conv(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample1D(nn.Module):
    """Double the temporal resolution of a 1D feature map.

    Wraps a single ``nn.ConvTranspose1d`` (kernel size 4, stride 2,
    padding 1) that keeps the number of channels unchanged. For an input
    of length ``T`` along the last axis the output length is exactly
    ``2 * T``.

    NOTE(review): because the output is always even, a sequence whose
    original length was odd and was halved with rounding up comes back one
    frame longer than it started; callers that concatenate with a skip
    connection need to account for that.

    Args:
        channels: Number of input channels; the output has the same number.
    """

    def __init__(self, channels):
        super().__init__()
        # kernel 4 / stride 2 / padding 1 -> L_out = (T - 1) * 2 - 2 + 4 = 2T.
        self.conv = nn.ConvTranspose1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        """Apply the transposed convolution.

        Args:
            x: Tensor in the layout ``nn.ConvTranspose1d`` expects, with
                ``channels`` on the channel axis and time on the last axis.

        Returns:
            Tensor with the same channel count and a last axis twice as long.
        """
        return self.conv(x)
|
||||||
|
|
||||||
|
|
||||||
class UNet(nn.Module):
|
class UNet(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -77,21 +95,49 @@ class UNet(nn.Module):
|
||||||
|
|
||||||
self.input_blocks = nn.ModuleList([])
|
self.input_blocks = nn.ModuleList([])
|
||||||
block_in_channels = in_channels * 2
|
block_in_channels = in_channels * 2
|
||||||
for _ in range(num_blocks):
|
block_out_channels = model_channels
|
||||||
|
for level in range(num_blocks):
|
||||||
block = nn.ModuleList([])
|
block = nn.ModuleList([])
|
||||||
|
|
||||||
block.append(
|
block.append(
|
||||||
ResNetBlock1D(
|
ResNetBlock1D(
|
||||||
in_channels=block_in_channels,
|
in_channels=block_in_channels,
|
||||||
out_channels=model_channels,
|
out_channels=block_out_channels,
|
||||||
time_embed_channels=time_embed_channels
|
time_embed_channels=time_embed_channels
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if level != num_blocks - 1:
|
||||||
|
block.append(Downsample1D(block_out_channels))
|
||||||
|
else:
|
||||||
|
block.append(None)
|
||||||
|
|
||||||
|
block_in_channels = block_out_channels
|
||||||
self.input_blocks.append(block)
|
self.input_blocks.append(block)
|
||||||
|
|
||||||
self.middle_blocks = nn.ModuleList([])
|
self.middle_blocks = nn.ModuleList([])
|
||||||
|
|
||||||
self.output_blocks = nn.ModuleList([])
|
self.output_blocks = nn.ModuleList([])
|
||||||
|
block_in_channels = block_out_channels * 2
|
||||||
|
block_out_channels = model_channels
|
||||||
|
for level in range(num_blocks):
|
||||||
|
block = nn.ModuleList([])
|
||||||
|
|
||||||
|
block.append(
|
||||||
|
ResNetBlock1D(
|
||||||
|
in_channels=block_in_channels,
|
||||||
|
out_channels=block_out_channels,
|
||||||
|
time_embed_channels=time_embed_channels
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if level != num_blocks - 1:
|
||||||
|
block.append(Upsample1D(block_out_channels))
|
||||||
|
else:
|
||||||
|
block.append(None)
|
||||||
|
|
||||||
|
block_in_channels = block_out_channels * 2
|
||||||
|
self.output_blocks.append(block)
|
||||||
|
|
||||||
self.conv_block = ConvBlock1D(model_channels, model_channels)
|
self.conv_block = ConvBlock1D(model_channels, model_channels)
|
||||||
self.conv = nn.Conv1d(model_channels, self.out_channels, 1)
|
self.conv = nn.Conv1d(model_channels, self.out_channels, 1)
|
||||||
|
@ -102,15 +148,33 @@ class UNet(nn.Module):
|
||||||
|
|
||||||
x_t = pack([x_t, mean], "b * t")[0]
|
x_t = pack([x_t, mean], "b * t")[0]
|
||||||
|
|
||||||
|
hidden_states = []
|
||||||
|
mask_states = [mask]
|
||||||
|
|
||||||
for block in self.input_blocks:
|
for block in self.input_blocks:
|
||||||
res_net_block = block[0]
|
res_net_block, downsample = block
|
||||||
|
|
||||||
x_t = res_net_block(x_t, mask, t)
|
x_t = res_net_block(x_t, mask, t)
|
||||||
|
hidden_states.append(x_t)
|
||||||
|
|
||||||
|
if downsample is not None:
|
||||||
|
x_t = downsample(x_t * mask)
|
||||||
|
mask = mask[:, :, ::2]
|
||||||
|
mask_states.append(mask)
|
||||||
|
|
||||||
for _ in self.middle_blocks:
|
for _ in self.middle_blocks:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
for _ in self.output_blocks:
|
for block in self.output_blocks:
|
||||||
pass
|
res_net_block, upsample = block
|
||||||
|
|
||||||
|
x_t = pack([x_t, hidden_states.pop()], "b * t")[0]
|
||||||
|
mask = mask_states.pop()
|
||||||
|
x_t = res_net_block(x_t, mask, t)
|
||||||
|
|
||||||
|
if upsample is not None:
|
||||||
|
x_t = upsample(x_t * mask)
|
||||||
|
|
||||||
|
|
||||||
output = self.conv_block(x_t)
|
output = self.conv_block(x_t)
|
||||||
output = self.conv(x_t)
|
output = self.conv(x_t)
|
||||||
|
|
|
@ -11,7 +11,7 @@ class Decoder(nn.Module):
|
||||||
self.sigma_min = 1e-5
|
self.sigma_min = 1e-5
|
||||||
self.predictor = UNet(
|
self.predictor = UNet(
|
||||||
in_channels=80,
|
in_channels=80,
|
||||||
model_channels=160,
|
model_channels=256,
|
||||||
out_channels=80,
|
out_channels=80,
|
||||||
num_blocks=2
|
num_blocks=2
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue