Add UNet backbone

This commit is contained in:
Subuday 2024-02-12 21:44:29 +00:00
parent 7314b1cbec
commit b5467b8051
3 changed files with 73 additions and 1 deletions

View File

@ -0,0 +1,64 @@
import math
from einops import pack
import torch
from torch import nn
class PositionalEncoding(torch.nn.Module):
    """Sinusoidal embedding of a scalar (e.g. a time step) per batch element.

    The embedding uses ``channels // 2`` geometrically spaced frequencies and
    concatenates the sine and cosine of each, so the output width is
    ``2 * (channels // 2)`` (i.e. ``channels - 1`` when ``channels`` is odd).
    """

    def __init__(self, channels):
        """Store the requested embedding width.

        Args:
            channels: Target width of the embedding.
        """
        super().__init__()
        self.channels = channels

    def forward(self, x, scale=1000):
        """Embed ``x`` as concatenated sines and cosines.

        Args:
            x: 0-D or 1-D tensor of values to embed; a 0-D tensor is treated
                as a batch of one.
            scale: Multiplier applied to ``x`` before the sinusoids.

        Returns:
            Tensor of shape ``(len(x), 2 * (channels // 2))`` laid out as
            ``[sin(...), cos(...)]`` along the last dimension.
        """
        if x.ndim < 1:
            x = x.unsqueeze(0)

        half_channels = self.channels // 2
        # With a single frequency the step is irrelevant (it is multiplied by
        # index 0), so clamp the divisor to avoid a ZeroDivisionError for
        # channels in {2, 3}. Larger sizes are unaffected.
        log_step = math.log(10000) / max(half_channels - 1, 1)
        frequencies = torch.exp(
            torch.arange(half_channels, device=x.device).float() * -log_step
        )
        angles = scale * x.unsqueeze(1) * frequencies.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
class UNet(nn.Module):
    """Skeleton UNet backbone: time embedding, empty block stacks, 1x1 conv.

    The input/middle/output block lists are created empty and the forward
    pass only iterates over them without applying anything, so the network
    currently reduces to a pointwise ``Conv1d`` over the packed input.
    """

    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        out_channels: int,
    ):
        """Build the time-embedding MLP, the block containers and the conv.

        Args:
            in_channels: Width of the sinusoidal time encoding and of the
                first linear layer's input.
            model_channels: Input channel count of the final convolution; the
                time-embedding width is four times this value.
            out_channels: Output channel count of the final convolution.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Time conditioning: sinusoidal encoding followed by a two-layer MLP.
        self.time_encoder = PositionalEncoding(in_channels)
        embed_width = 4 * model_channels
        first_projection = nn.Linear(in_channels, embed_width)
        second_projection = nn.Linear(embed_width, embed_width)
        self.time_embed = nn.Sequential(
            first_projection,
            nn.SiLU(),
            second_projection,
        )

        # Placeholder stages; nothing is registered in them yet.
        self.input_blocks = nn.ModuleList()
        self.middle_blocks = nn.ModuleList()
        self.output_blocks = nn.ModuleList()

        # NOTE(review): the conv consumes `model_channels` channels while
        # forward() feeds it the packed (x_t, mean) tensor directly -- this
        # presumably only lines up when the packed channel count equals
        # `model_channels`; confirm once the blocks are implemented.
        self.conv = nn.Conv1d(
            in_channels=model_channels,
            out_channels=out_channels,
            kernel_size=1,
        )

    def forward(self, x_t, mean, mask, t):
        """Pack ``x_t`` with ``mean``, apply the 1x1 conv and mask the result.

        Args:
            x_t: Tensor packed together with ``mean`` along the middle axis
                of the ``"b * t"`` pattern.
            mean: Tensor packed after ``x_t``.
            mask: Tensor multiplied element-wise (with broadcasting) into the
                convolution output.
            t: Values passed through the time encoder and time-embedding MLP.
                The resulting embedding is computed but not consumed yet.

        Returns:
            The convolution output multiplied by ``mask``.
        """
        # Computed for parity with the intended design; unused until the
        # block stacks are populated.
        t = self.time_embed(self.time_encoder(t))

        x_t, _ = pack([x_t, mean], "b * t")

        # The stages are traversed in order but apply nothing at present.
        stages = (self.input_blocks, self.middle_blocks, self.output_blocks)
        for stage in stages:
            for _ in stage:
                pass

        return self.conv(x_t) * mask

View File

@ -2,11 +2,18 @@ import torch
from torch import nn
import torch.nn.functional as F
from TTS.tts.layers.matcha_tts.UNet import UNet
class Decoder(nn.Module):
def __init__(self):
    """Set the ``sigma_min`` constant and build the UNet predictor."""
    super().__init__()
    self.sigma_min = 1e-5

    # Fixed predictor configuration, forwarded to UNet as keyword arguments.
    predictor_config = {
        "in_channels": 80,
        "model_channels": 160,
        "out_channels": 80,
    }
    self.predictor = UNet(**predictor_config)
def forward(self, x_1, mean, mask):
"""
@ -19,6 +26,6 @@ class Decoder(nn.Module):
x_0 = torch.randn_like(x_1)
x_t = (1 - (1 - self.sigma_min) * t) * x_0 + t * x_1
u_t = x_1 - (1 - self.sigma_min) * x_0
v_t = torch.randn_like(u_t)
v_t = self.predictor(x_t, mean, mask, t.squeeze())
loss = F.mse_loss(v_t, u_t, reduction="sum") / (torch.sum(mask) * u_t.shape[1])
return loss

View File

@ -33,3 +33,4 @@ class TestMatchTTS(unittest.TestCase):
def test_forward(self):
self._test_forward(1)
self._test_forward(3)