From 7314b1cbec969647ba777af59ee96d0d475cb264 Mon Sep 17 00:00:00 2001
From: Subuday
Date: Mon, 12 Feb 2024 19:39:22 +0000
Subject: [PATCH] Implement model forward

---
 TTS/tts/layers/matcha_tts/decoder.py | 45 +++++++++++++++++++
 TTS/tts/models/matcha_tts.py         | 68 +++++++++++++++++++++++++++-
 tests/tts_tests2/test_matcha_tts.py  |  2 +
 3 files changed, 113 insertions(+), 2 deletions(-)
 create mode 100644 TTS/tts/layers/matcha_tts/decoder.py

diff --git a/TTS/tts/layers/matcha_tts/decoder.py b/TTS/tts/layers/matcha_tts/decoder.py
new file mode 100644
index 00000000..de7f52dc
--- /dev/null
+++ b/TTS/tts/layers/matcha_tts/decoder.py
@@ -0,0 +1,45 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class Decoder(nn.Module):
+    """Build the flow-matching training loss for a target spectrogram.
+
+    Args:
+        sigma_min (float): Noise scale left on the path at ``t = 1``. Defaults to 1e-5.
+    """
+
+    def __init__(self, sigma_min=1e-5):
+        super().__init__()
+        self.sigma_min = sigma_min
+
+    def forward(self, x_1, mean, mask):
+        """Return the loss between a predicted and the target vector field.
+
+        Args:
+            x_1 (torch.Tensor): Target sample.
+            mean (torch.Tensor): Conditioning mean. Not used yet.
+            mask (torch.Tensor): Frame mask; its sum normalises the loss.
+
+        Returns:
+            torch.Tensor: Scalar loss.
+
+        Shapes:
+            - x_1: :math:`[B, C, T]`
+            - mean: :math:`[B, C, T]`
+            - mask: :math:`[B, 1, T]`
+        """
+        # One random time in [0, 1) per batch item, broadcast over C and T.
+        t = torch.rand([x_1.size(0), 1, 1], device=x_1.device, dtype=x_1.dtype)
+        x_0 = torch.randn_like(x_1)
+        # Interpolate from noise (t = 0) to data plus ``sigma_min`` noise (t = 1).
+        x_t = (1 - (1 - self.sigma_min) * t) * x_0 + t * x_1
+        # Target vector field: the derivative of ``x_t`` with respect to ``t``.
+        u_t = x_1 - (1 - self.sigma_min) * x_0
+        # TODO: ``v_t`` is random noise, so ``x_t`` and ``mean`` are unused for now;
+        # presumably a learned estimator of ``u_t`` is meant to replace it - confirm.
+        v_t = torch.randn_like(u_t)
+        # Squared error summed over all elements, divided by (unmasked frames * channels).
+        loss = F.mse_loss(v_t, u_t, reduction="sum") / (torch.sum(mask) * u_t.shape[1])
+        return loss
diff --git a/TTS/tts/models/matcha_tts.py b/TTS/tts/models/matcha_tts.py
index 08c0022b..9bc3e0ff 100644
--- a/TTS/tts/models/matcha_tts.py
+++ b/TTS/tts/models/matcha_tts.py
@@ -1,7 +1,12 @@
+from dataclasses import field
+import math
 import torch
 
 from TTS.tts.configs.matcha_tts import MatchaTTSConfig
+from TTS.tts.layers.glow_tts.encoder import Encoder
+from TTS.tts.layers.matcha_tts.decoder import Decoder
 from TTS.tts.models.base_tts import BaseTTS
+from TTS.tts.utils.helpers import maximum_path, sequence_mask
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 
 
@@ -14,9 +19,68 @@ class MatchaTTS(BaseTTS):
         tokenizer: "TTSTokenizer" = None,
     ):
         super().__init__(config, ap, tokenizer)
+        self.encoder = Encoder(
+            self.config.num_chars,
+            out_channels=80,
+            hidden_channels=192,
+            hidden_channels_dp=256,
+            encoder_type='rel_pos_transformer',
+            encoder_params={
+                "kernel_size": 3,
+                "dropout_p": 0.1,
+                "num_layers": 6,
+                "num_heads": 2,
+                "hidden_channels_ffn": 768,
+            }
+        )
 
-    def forward(self):
-        pass
+        self.decoder = Decoder()
+
+    def forward(self, x, x_lengths, y, y_lengths):
+        """Align the encoded text with the target frames and compute the decoder loss.
+
+        Args:
+            x (torch.Tensor):
+                Input text sequence ids. :math:`[B, T_en]`
+
+            x_lengths (torch.Tensor):
+                Lengths of input text sequences. :math:`[B]`
+
+            y (torch.Tensor):
+                Target mel-spectrogram frames. :math:`[B, T_de, C_mel]`
+
+            y_lengths (torch.Tensor):
+                Lengths of target mel-spectrogram frames. :math:`[B]`
+
+        Returns:
+            dict: ``"loss"`` holds the scalar loss returned by the decoder.
+        """
+        # Work channel-first from here on: [B, C_mel, T_de].
+        y = y.transpose(1, 2)
+        y_max_length = y.size(2)
+
+        o_mean, o_log_scale, o_log_dur, o_mask = self.encoder(x, x_lengths, g=None)
+
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(o_mask.dtype)
+        attn_mask = torch.unsqueeze(o_mask, -1) * torch.unsqueeze(y_mask, 2)
+
+        # The alignment search is not differentiated.
+        with torch.no_grad():
+            # Gaussian log-likelihood of each frame under each text position, expanded
+            # into four terms so that the cross terms become matmuls.
+            o_scale = torch.exp(-2 * o_log_scale)
+            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)
+            logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (y**2))
+            logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), y)
+            logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1)
+            logp = logp1 + logp2 + logp3 + logp4
+            attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
+
+        # Align encoded text with mel-spectrogram and get mu_y segment
+        c_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(1, 2)
+
+        loss = self.decoder(x_1=y, mean=c_mean, mask=y_mask)
+        return {"loss": loss}
 
     @torch.no_grad()
     def inference(self):
diff --git a/tests/tts_tests2/test_matcha_tts.py b/tests/tts_tests2/test_matcha_tts.py
index 1939efbd..bc94c6b4 100644
--- a/tests/tts_tests2/test_matcha_tts.py
+++ b/tests/tts_tests2/test_matcha_tts.py
@@ -29,5 +29,7 @@ class TestMatchTTS(unittest.TestCase):
 
         model.train()
 
+        model.forward(input_dummy, input_lengths, mel_spec, mel_lengths)
+
     def test_forward(self):
         self._test_forward(1)