diff --git a/TTS/tts/layers/matcha_tts/UNet.py b/TTS/tts/layers/matcha_tts/UNet.py index 0183c787..142ef98f 100644 --- a/TTS/tts/layers/matcha_tts/UNet.py +++ b/TTS/tts/layers/matcha_tts/UNet.py @@ -44,7 +44,7 @@ class ResNetBlock1D(nn.Module): nn.Mish(), nn.Linear(time_embed_channels, out_channels) ) - self.block_2 = ConvBlock1D(in_channels, out_channels, num_groups=num_groups) + self.block_2 = ConvBlock1D(in_channels=out_channels, out_channels=out_channels, num_groups=num_groups) self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1) def forward(self, x, mask, t): @@ -55,6 +55,24 @@ class ResNetBlock1D(nn.Module): return output +class Downsample1D(nn.Module): + def __init__(self, channels): + super().__init__() + self.conv = nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=3, stride=2, padding=1) + + def forward(self, x): + return self.conv(x) + + +class Upsample1D(nn.Module): + def __init__(self, channels): + super().__init__() + self.conv = nn.ConvTranspose1d(in_channels=channels, out_channels=channels, kernel_size=4, stride=2, padding=1) + + def forward(self, x): + return self.conv(x) + + class UNet(nn.Module): def __init__( self, @@ -77,21 +95,49 @@ class UNet(nn.Module): self.input_blocks = nn.ModuleList([]) block_in_channels = in_channels * 2 - for _ in range(num_blocks): + block_out_channels = model_channels + for level in range(num_blocks): block = nn.ModuleList([]) block.append( ResNetBlock1D( in_channels=block_in_channels, - out_channels=model_channels, + out_channels=block_out_channels, time_embed_channels=time_embed_channels ) ) + if level != num_blocks - 1: + block.append(Downsample1D(block_out_channels)) + else: + block.append(None) + + block_in_channels = block_out_channels self.input_blocks.append(block) self.middle_blocks = nn.ModuleList([]) + self.output_blocks = nn.ModuleList([]) + block_in_channels = block_out_channels * 2 + block_out_channels = model_channels + for level in range(num_blocks): + block = nn.ModuleList([]) + + block.append( + ResNetBlock1D( + in_channels=block_in_channels, + out_channels=block_out_channels, + time_embed_channels=time_embed_channels + ) + ) + + if level != num_blocks - 1: + block.append(Upsample1D(block_out_channels)) + else: + block.append(None) + + block_in_channels = block_out_channels * 2 + self.output_blocks.append(block) self.conv_block = ConvBlock1D(model_channels, model_channels) self.conv = nn.Conv1d(model_channels, self.out_channels, 1) @@ -102,15 +148,33 @@ class UNet(nn.Module): x_t = pack([x_t, mean], "b * t")[0] + hidden_states = [] + mask_states = [mask] + for block in self.input_blocks: - res_net_block = block[0] + res_net_block, downsample = block + x_t = res_net_block(x_t, mask, t) + hidden_states.append(x_t) + + if downsample is not None: + x_t = downsample(x_t * mask) + mask = mask[:, :, ::2] + mask_states.append(mask) for _ in self.middle_blocks: pass - for _ in self.output_blocks: - pass + for block in self.output_blocks: + res_net_block, upsample = block + + x_t = pack([x_t, hidden_states.pop()], "b * t")[0] + mask = mask_states.pop() + x_t = res_net_block(x_t, mask, t) + + if upsample is not None: + x_t = upsample(x_t * mask) + output = self.conv_block(x_t) output = self.conv(x_t) diff --git a/TTS/tts/layers/matcha_tts/decoder.py b/TTS/tts/layers/matcha_tts/decoder.py index b80c190e..c87da9d5 100644 --- a/TTS/tts/layers/matcha_tts/decoder.py +++ b/TTS/tts/layers/matcha_tts/decoder.py @@ -11,7 +11,7 @@ class Decoder(nn.Module): self.sigma_min = 1e-5 self.predictor = UNet( in_channels=80, - model_channels=160, + model_channels=256, out_channels=80, num_blocks=2 )