mirror of https://github.com/coqui-ai/TTS.git
25 lines
738 B
Python
25 lines
738 B
Python
import torch
|
|
from torch import nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class Decoder(nn.Module):
    """Loss module built around a noise-to-data interpolation.

    The formulas in ``forward`` appear to follow a conditional flow-matching
    setup (straight path from noise ``x_0`` to data ``x_1`` with a small
    residual noise floor ``sigma_min``).

    NOTE(review): this module registers no parameters or sub-modules, and the
    "predicted" field in ``forward`` is sampled from ``torch.randn_like``
    instead of coming from a network -- presumably a placeholder for an
    estimator that is not wired in yet; confirm before relying on the loss.
    """

    def __init__(self):
        super().__init__()
        # Noise floor used both in the interpolation and in the target field.
        self.sigma_min = 1e-5

    def forward(self, x_1, mean, mask):
        """Return a scalar squared-error loss between a sampled field and the target field.

        Args:
            x_1: Data tensor; the noise sample and the target are shaped like it.
            mean: Not read by the current implementation.
            mask: Only its sum is used, as part of the loss normaliser. The
                squared error itself is NOT masked.

        Returns:
            0-dim tensor: summed squared error divided by
            ``sum(mask) * x_1.shape[1]``.

        Shapes:
            - x_1: :math:`[B, C, T]`
            - mean: :math:`[B, C, T]`
            - mask: :math:`[B, 1, T]`
        """
        batch_size = x_1.size(0)
        # (1 - sigma_min) appears in both the path and the target; compute once.
        keep_ratio = 1 - self.sigma_min

        # One random time step per batch item, broadcast over channels and frames.
        # The order of the random draws below (t, noise, v_t) determines the
        # result under a fixed seed, so it must not be rearranged.
        t = torch.rand([batch_size, 1, 1], device=x_1.device, dtype=x_1.dtype)
        noise = torch.randn_like(x_1)

        # Point on the path between noise (t=0) and data (t=1).
        # NOTE(review): currently unused -- presumably the future estimator input.
        x_t = (1 - keep_ratio * t) * noise + t * x_1

        # Regression target.
        u_t = x_1 - keep_ratio * noise

        # Stand-in for the predicted field: pure Gaussian noise shaped like the target.
        v_t = torch.randn_like(u_t)

        squared_error = F.mse_loss(v_t, u_t, reduction="sum")
        # Normalise by (sum of mask) * (size of dim 1 of the target).
        normaliser = torch.sum(mask) * u_t.shape[1]
        return squared_error / normaliser
|