mirror of https://github.com/coqui-ai/TTS.git

Implement model forward

parent 8c4d0142b7
commit 7314b1cbec
@@ -0,0 +1,24 @@
import torch
from torch import nn
import torch.nn.functional as F


class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.sigma_min = 1e-5

    def forward(self, x_1, mean, mask):
        """
        Shapes:
            - x_1: :math:`[B, C, T]`
            - mean: :math:`[B, C, T]`
            - mask: :math:`[B, 1, T]`
        """
        # Sample a per-item flow-matching time step t ~ U(0, 1).
        t = torch.rand([x_1.size(0), 1, 1], device=x_1.device, dtype=x_1.dtype)
        # Noise sample x_0 and the conditional flow x_t interpolating x_0 -> x_1.
        x_0 = torch.randn_like(x_1)
        x_t = (1 - (1 - self.sigma_min) * t) * x_0 + t * x_1
        # Target vector field of the optimal-transport conditional flow.
        u_t = x_1 - (1 - self.sigma_min) * x_0
        # Placeholder for the vector-field estimator's prediction; the
        # estimator network (which will consume x_t, mean, mask and t) is
        # not implemented in this commit.
        v_t = torch.randn_like(u_t)
        # Sum of squared errors over all elements (padding is not yet masked
        # out of the numerator), normalized by the number of valid frames
        # times the number of channels.
        loss = F.mse_loss(v_t, u_t, reduction="sum") / (torch.sum(mask) * u_t.shape[1])
        return loss
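A quick way to see why u_t is the right regression target: it is the time derivative of the conditional flow x_t defined above, which for this linear interpolation is constant in t. A minimal standalone sketch, not part of the commit (shapes, eps, and sigma_min chosen arbitrarily), checking this with a finite difference:

    import torch

    sigma_min = 1e-5
    x_0 = torch.randn(2, 80, 100)  # noise sample, [B, C, T]
    x_1 = torch.randn(2, 80, 100)  # data sample, [B, C, T]
    t = torch.rand(2, 1, 1)
    eps = 1e-3

    def x_t(t):
        # Conditional flow from Decoder.forward above.
        return (1 - (1 - sigma_min) * t) * x_0 + t * x_1

    u_t = x_1 - (1 - sigma_min) * x_0     # closed-form target vector field
    u_fd = (x_t(t + eps) - x_t(t)) / eps  # finite-difference d(x_t)/dt
    print(torch.allclose(u_t, u_fd, atol=1e-2))  # True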
@@ -1,7 +1,12 @@
from dataclasses import field
import math

import torch

from TTS.tts.configs.matcha_tts import MatchaTTSConfig
from TTS.tts.layers.glow_tts.encoder import Encoder
from TTS.tts.layers.matcha_tts.decoder import Decoder
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import maximum_path, sequence_mask
from TTS.tts.utils.text.tokenizer import TTSTokenizer
@@ -14,9 +19,59 @@ class MatchaTTS(BaseTTS):
        tokenizer: "TTSTokenizer" = None,
    ):
        super().__init__(config, ap, tokenizer)
        self.encoder = Encoder(
            self.config.num_chars,
            out_channels=80,
            hidden_channels=192,
            hidden_channels_dp=256,
            encoder_type="rel_pos_transformer",
            encoder_params={
                "kernel_size": 3,
                "dropout_p": 0.1,
                "num_layers": 6,
                "num_heads": 2,
                "hidden_channels_ffn": 768,
            }
        )
        self.decoder = Decoder()

    def forward(self, x, x_lengths, y, y_lengths):
        """
        Args:
            x (torch.Tensor):
                Input text sequence ids. :math:`[B, T_en]`

            x_lengths (torch.Tensor):
                Lengths of input text sequences. :math:`[B]`

            y (torch.Tensor):
                Target mel-spectrogram frames. :math:`[B, T_de, C_mel]`

            y_lengths (torch.Tensor):
                Lengths of target mel-spectrogram frames. :math:`[B]`
        """
        # [B, T_de, C_mel] -> [B, C_mel, T_de]
        y = y.transpose(1, 2)
        y_max_length = y.size(2)

        # Encode text into per-token Gaussian parameters (mean, log-scale),
        # a log-duration prediction, and a text mask.
        o_mean, o_log_scale, o_log_dur, o_mask = self.encoder(x, x_lengths, g=None)

        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(o_mask.dtype)
        attn_mask = torch.unsqueeze(o_mask, -1) * torch.unsqueeze(y_mask, 2)

        # Monotonic alignment search (MAS): logp[b, j, i] is the diagonal-
        # Gaussian log-likelihood of mel frame i under text token j, expanded
        # into four matmul-friendly terms.
        with torch.no_grad():
            o_scale = torch.exp(-2 * o_log_scale)  # 1 / sigma^2
            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)
            logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (y**2))
            logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), y)
            logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1)
            logp = logp1 + logp2 + logp3 + logp4
            attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()

        # Align encoded text with the mel-spectrogram to get the frame-level means.
        c_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(1, 2)

        _ = self.decoder(x_1=y, mean=c_mean, mask=y_mask)

    @torch.no_grad()
    def inference(self):
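For reference, the four logp terms in the MAS block are an expansion of the diagonal-Gaussian log-likelihood log N(y_i; mu_j, sigma_j^2), summed over channels, into pieces computable with matmuls. A small standalone check against torch.distributions, not part of the commit (toy shapes):

    import math
    import torch

    B, C, T_en, T_de = 1, 3, 4, 5
    o_mean = torch.randn(B, C, T_en)
    o_log_scale = torch.randn(B, C, T_en)
    y = torch.randn(B, C, T_de)

    # Same four terms as in MatchaTTS.forward.
    o_scale = torch.exp(-2 * o_log_scale)  # 1 / sigma^2
    logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)
    logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (y**2))
    logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), y)
    logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1)
    logp = logp1 + logp2 + logp3 + logp4  # [B, T_en, T_de]

    # Reference: sum per-channel Normal log-densities directly, token by token.
    ref = torch.stack(
        [
            torch.distributions.Normal(o_mean[..., j:j + 1], torch.exp(o_log_scale[..., j:j + 1]))
            .log_prob(y)
            .sum(1)
            for j in range(T_en)
        ],
        dim=1,
    )
    print(torch.allclose(logp, ref, atol=1e-4))  # True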
@@ -29,5 +29,7 @@ class TestMatchTTS(unittest.TestCase):
        model.train()

        model.forward(input_dummy, input_lengths, mel_spec, mel_lengths)

    def test_forward(self):
        self._test_forward(1)