diff --git a/TTS/tts/layers/matcha_tts/UNet.py b/TTS/tts/layers/matcha_tts/UNet.py
new file mode 100644
index 00000000..07616290
--- /dev/null
+++ b/TTS/tts/layers/matcha_tts/UNet.py
@@ -0,0 +1,64 @@
+import math
+from einops import pack
+import torch
+from torch import nn
+
+
+class PositionalEncoding(torch.nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.channels = channels
+
+    def forward(self, x, scale=1000):
+        if x.ndim < 1:
+            x = x.unsqueeze(0)
+        emb = math.log(10000) / (self.channels // 2 - 1)
+        emb = torch.exp(torch.arange(self.channels // 2, device=x.device).float() * -emb)
+        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
+
+
+class UNet(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        model_channels: int,
+        out_channels: int,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        self.time_encoder = PositionalEncoding(in_channels)
+        time_embed_dim = model_channels * 4
+        self.time_embed = nn.Sequential(
+            nn.Linear(in_channels, time_embed_dim),
+            nn.SiLU(),
+            nn.Linear(time_embed_dim, time_embed_dim),
+        )
+
+        self.input_blocks = nn.ModuleList([])
+        self.middle_blocks = nn.ModuleList([])
+        self.output_blocks = nn.ModuleList([])
+
+        self.conv = nn.Conv1d(model_channels, self.out_channels, 1)
+
+    def forward(self, x_t, mean, mask, t):
+        t = self.time_encoder(t)
+        t = self.time_embed(t)
+
+        x_t = pack([x_t, mean], "b * t")[0]
+
+        for _ in self.input_blocks:
+            pass
+
+        for _ in self.middle_blocks:
+            pass
+
+        for _ in self.output_blocks:
+            pass
+
+        output = self.conv(x_t)
+
+        return output * mask
\ No newline at end of file
diff --git a/TTS/tts/layers/matcha_tts/decoder.py b/TTS/tts/layers/matcha_tts/decoder.py
index de7f52dc..e78d34cf 100644
--- a/TTS/tts/layers/matcha_tts/decoder.py
+++ b/TTS/tts/layers/matcha_tts/decoder.py
@@ -2,11 +2,18 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
+from TTS.tts.layers.matcha_tts.UNet import UNet
+
 
 class Decoder(nn.Module):
     def __init__(self):
         super().__init__()
         self.sigma_min = 1e-5
+        self.predictor = UNet(
+            in_channels=80,
+            model_channels=160,
+            out_channels=80,
+        )
 
     def forward(self, x_1, mean, mask):
         """
@@ -19,6 +26,6 @@ class Decoder(nn.Module):
         x_0 = torch.randn_like(x_1)
         x_t = (1 - (1 - self.sigma_min) * t) * x_0 + t * x_1
         u_t = x_1 - (1 - self.sigma_min) * x_0
-        v_t = torch.randn_like(u_t)
+        v_t = self.predictor(x_t, mean, mask, t.squeeze())
         loss = F.mse_loss(v_t, u_t, reduction="sum") / (torch.sum(mask) * u_t.shape[1])
         return loss
diff --git a/tests/tts_tests2/test_matcha_tts.py b/tests/tts_tests2/test_matcha_tts.py
index bc94c6b4..5fbe9537 100644
--- a/tests/tts_tests2/test_matcha_tts.py
+++ b/tests/tts_tests2/test_matcha_tts.py
@@ -33,3 +33,4 @@ class TestMatchTTS(unittest.TestCase):
 
     def test_forward(self):
         self._test_forward(1)
+        self._test_forward(3)
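For a quick sanity check of the new predictor outside the test suite, the snippet below runs the UNet skeleton on dummy tensors. This is only an illustrative sketch, not part of the patch: the 80-channel mel-like inputs, the (batch, 1, frames) mask shape, and the per-item time values are assumptions based on how the Decoder calls the predictor.

import torch

from TTS.tts.layers.matcha_tts.UNet import UNet

net = UNet(in_channels=80, model_channels=160, out_channels=80)
x_t = torch.randn(2, 80, 100)   # noisy sample at time t (assumed mel-shaped input)
mean = torch.randn(2, 80, 100)  # conditioning mean from the encoder, same shape
mask = torch.ones(2, 1, 100)    # frame validity mask (assumed shape)
t = torch.rand(2)               # one flow-matching time per batch item
out = net(x_t, mean, mask, t)
print(out.shape)                # torch.Size([2, 80, 100])

Inside the forward pass, x_t and mean are packed along the channel axis into 160 channels and projected back to 80 by the final 1x1 convolution; the time embedding is computed but not yet consumed, since the input, middle, and output block lists are still empty placeholders.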