From 0f7a7edb9bdb7d75291144ee8b276331a4f16df5 Mon Sep 17 00:00:00 2001 From: Subuday Date: Wed, 14 Feb 2024 21:21:07 +0000 Subject: [PATCH] Add conv block to UNet --- TTS/tts/layers/matcha_tts/UNet.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/TTS/tts/layers/matcha_tts/UNet.py b/TTS/tts/layers/matcha_tts/UNet.py index 07616290..642a8545 100644 --- a/TTS/tts/layers/matcha_tts/UNet.py +++ b/TTS/tts/layers/matcha_tts/UNet.py @@ -18,6 +18,23 @@ class PositionalEncoding(torch.nn.Module): emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb +class ConvBlock1D(nn.Module): + def __init__(self, in_channels, out_channels, num_groups=8): + super().__init__() + self.block = nn.Sequential( + nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1), + nn.GroupNorm(num_groups, out_channels), + nn.Mish() + ) + + def forward(self, x, mask=None): + if mask is not None: + x = x * mask + output = self.block(x) + if mask is not None: + output = output * mask + return output + class UNet(nn.Module): def __init__( @@ -42,6 +59,7 @@ class UNet(nn.Module): self.middle_blocks = nn.ModuleList([]) self.output_blocks = nn.ModuleList([]) + self.conv_block = ConvBlock1D(model_channels, model_channels) self.conv = nn.Conv1d(model_channels, self.out_channels, 1) def forward(self, x_t, mean, mask, t): @@ -59,6 +77,7 @@ class UNet(nn.Module): for _ in self.output_blocks: pass + output = self.conv_block(x_t) output = self.conv(x_t) return output * mask \ No newline at end of file