diff --git a/TTS/tts/layers/matcha_tts/UNet.py b/TTS/tts/layers/matcha_tts/UNet.py
index 642a8545..8547bb9b 100644
--- a/TTS/tts/layers/matcha_tts/UNet.py
+++ b/TTS/tts/layers/matcha_tts/UNet.py
@@ -36,26 +36,71 @@ class ConvBlock1D(nn.Module):
         return output
 
 
+class ResNetBlock1D(nn.Module):
+    """Two ConvBlock1D stages with a time-embedding shift and a 1x1-conv residual path."""
+
+    def __init__(self, in_channels, out_channels, time_embed_channels, num_groups=8):
+        super().__init__()
+        self.block_1 = ConvBlock1D(in_channels, out_channels, num_groups=num_groups)
+        # Projects the time embedding to one value per output channel.
+        self.mlp = nn.Sequential(
+            nn.Mish(),
+            nn.Linear(time_embed_channels, out_channels)
+        )
+        # block_2 consumes block_1's output, so its input width is out_channels.
+        self.block_2 = ConvBlock1D(out_channels, out_channels, num_groups=num_groups)
+        # 1x1 conv maps the input to out_channels so it can be added to the block output.
+        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)
+
+    def forward(self, x, mask, t):
+        h = self.block_1(x, mask)
+        # unsqueeze(-1) lets the projected embedding broadcast along the last axis of h.
+        h += self.mlp(t).unsqueeze(-1)
+        h = self.block_2(h, mask)
+        output = h + self.conv(x * mask)
+        return output
+
+
 class UNet(nn.Module):
     def __init__(
         self,
         in_channels: int,
         model_channels: int,
         out_channels: int,
+        num_blocks: int,
     ):
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
 
         self.time_encoder = PositionalEncoding(in_channels)
-        time_embed_dim = model_channels * 4
+        time_embed_channels = model_channels * 4
         self.time_embed = nn.Sequential(
-            nn.Linear(in_channels, time_embed_dim),
+            nn.Linear(in_channels, time_embed_channels),
             nn.SiLU(),
-            nn.Linear(time_embed_dim, time_embed_dim),
+            nn.Linear(time_embed_channels, time_embed_channels),
         )
 
         self.input_blocks = nn.ModuleList([])
+        # NOTE(review): forward() packs x_t with mean before these blocks run;
+        # confirm the first block's in_channels matches the packed width.
+        block_in_channels = in_channels
+        for _ in range(num_blocks):
+            block = nn.ModuleList([])
+
+            block.append(
+                ResNetBlock1D(
+                    in_channels=block_in_channels,
+                    out_channels=model_channels,
+                    time_embed_channels=time_embed_channels
+                )
+            )
+
+            # Register the stage so forward() actually iterates over it, and
+            # feed this stage's output width into the next one.
+            self.input_blocks.append(block)
+            block_in_channels = model_channels
+
         self.middle_blocks = nn.ModuleList([])
         self.output_blocks = nn.ModuleList([])
 
@@ -68,8 +113,9 @@ class UNet(nn.Module):
         x_t = pack([x_t, mean], "b * t")[0]
 
-        for _ in self.input_blocks:
-            pass
+        for block in self.input_blocks:
+            res_net_block = block[0]
+            x_t = res_net_block(x_t, mask, t)
 
         for _ in self.middle_blocks:
             pass
 
diff --git a/TTS/tts/layers/matcha_tts/decoder.py b/TTS/tts/layers/matcha_tts/decoder.py
index e78d34cf..b80c190e 100644
--- a/TTS/tts/layers/matcha_tts/decoder.py
+++ b/TTS/tts/layers/matcha_tts/decoder.py
@@ -13,6 +13,7 @@ class Decoder(nn.Module):
             in_channels=80,
             model_channels=160,
             out_channels=80,
+            num_blocks=2
         )
 
     def forward(self, x_1, mean, mask):