mirror of https://github.com/coqui-ai/TTS.git
33 lines
965 B
Python
33 lines
965 B
Python
import torch
|
|
from torch import nn
|
|
import torch.nn.functional as F
|
|
|
|
from TTS.tts.layers.matcha_tts.UNet import UNet
|
|
|
|
|
|
class Decoder(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.sigma_min = 1e-5
|
|
self.predictor = UNet(
|
|
in_channels=80,
|
|
model_channels=256,
|
|
out_channels=80,
|
|
num_blocks=2
|
|
)
|
|
|
|
def forward(self, x_1, mean, mask):
|
|
"""
|
|
Shapes:
|
|
- x_1: :math:`[B, C, T]`
|
|
- mean: :math:`[B, C ,T]`
|
|
- mask: :math:`[B, 1, T]`
|
|
"""
|
|
t = torch.rand([x_1.size(0), 1, 1], device=x_1.device, dtype=x_1.dtype)
|
|
x_0 = torch.randn_like(x_1)
|
|
x_t = (1 - (1 - self.sigma_min) * t) * x_0 + t * x_1
|
|
u_t = x_1 - (1 - self.sigma_min) * x_0
|
|
v_t = self.predictor(x_t, mean, mask, t.squeeze())
|
|
loss = F.mse_loss(v_t, u_t, reduction="sum") / (torch.sum(mask) * u_t.shape[1])
|
|
return loss
|