Revert "Merge remote-tracking branch 'subuday/matcha_tts' into dev"

This reverts commit f6a23c1d8a, reversing
changes made to 275229a876.
David Martin Rius 2024-03-05 22:38:59 +01:00
parent 89b5322666
commit 29a0e189b2
5 changed files with 0 additions and 461 deletions

TTS/tts/configs/matcha_tts.py

@@ -1,9 +0,0 @@
from dataclasses import dataclass, field

from TTS.tts.configs.shared_configs import BaseTTSConfig


@dataclass
class MatchaTTSConfig(BaseTTSConfig):
    model: str = "matcha_tts"
    num_chars: int = None

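For context, a minimal sketch of using this config (the num_chars value is arbitrary here; it must be set explicitly since it defaults to None, as the test file at the end of this diff also does):

from TTS.tts.configs.matcha_tts import MatchaTTSConfig

config = MatchaTTSConfig(num_chars=32)
print(config.model)  # "matcha_tts"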
TTS/tts/layers/matcha_tts/UNet.py

@@ -1,299 +0,0 @@
import math

from einops import pack, rearrange
import torch
from torch import nn
import conformer


class PositionalEncoding(torch.nn.Module):
    """Sinusoidal embedding of the flow-matching timestep."""

    def __init__(self, channels):
        super().__init__()
        self.channels = channels

    def forward(self, x, scale=1000):
        if x.ndim < 1:
            x = x.unsqueeze(0)
        emb = math.log(10000) / (self.channels // 2 - 1)
        emb = torch.exp(torch.arange(self.channels // 2, device=x.device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class ConvBlock1D(nn.Module):
    def __init__(self, in_channels, out_channels, num_groups=8):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(num_groups, out_channels),
            nn.Mish(),
        )

    def forward(self, x, mask=None):
        if mask is not None:
            x = x * mask
        output = self.block(x)
        if mask is not None:
            output = output * mask
        return output


class ResNetBlock1D(nn.Module):
    def __init__(self, in_channels, out_channels, time_embed_channels, num_groups=8):
        super().__init__()
        self.block_1 = ConvBlock1D(in_channels, out_channels, num_groups=num_groups)
        self.mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(time_embed_channels, out_channels),
        )
        self.block_2 = ConvBlock1D(in_channels=out_channels, out_channels=out_channels, num_groups=num_groups)
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)

    def forward(self, x, mask, t):
        h = self.block_1(x, mask)
        h += self.mlp(t).unsqueeze(-1)  # broadcast the time embedding over all frames
        h = self.block_2(h, mask)
        output = h + self.conv(x * mask)  # residual path with 1x1 channel projection
        return output


class Downsample1D(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)


class Upsample1D(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.ConvTranspose1d(in_channels=channels, out_channels=channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)


class ConformerBlock(conformer.ConformerBlock):
    def __init__(
        self,
        dim: int,
        dim_head: int = 64,
        heads: int = 8,
        ff_mult: int = 4,
        conv_expansion_factor: int = 2,
        conv_kernel_size: int = 31,
        attn_dropout: float = 0.0,
        ff_dropout: float = 0.0,
        conv_dropout: float = 0.0,
        conv_causal: bool = False,
    ):
        super().__init__(
            dim=dim,
            dim_head=dim_head,
            heads=heads,
            ff_mult=ff_mult,
            conv_expansion_factor=conv_expansion_factor,
            conv_kernel_size=conv_kernel_size,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            conv_dropout=conv_dropout,
            conv_causal=conv_causal,
        )

    def forward(self, x, mask):
        # conformer expects [B, T, C] input and a boolean [B, T] mask.
        x = rearrange(x, "b c t -> b t c")
        mask = rearrange(mask, "b 1 t -> b t")
        output = super().forward(x=x, mask=mask.bool())
        return rearrange(output, "b t c -> b c t")


class UNet(nn.Module):
    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        out_channels: int,
        num_blocks: int,
        transformer_num_heads: int = 4,
        transformer_dim_head: int = 64,
        transformer_ff_mult: int = 1,
        transformer_conv_expansion_factor: int = 2,
        transformer_conv_kernel_size: int = 31,
        transformer_dropout: float = 0.05,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.time_encoder = PositionalEncoding(in_channels)
        time_embed_channels = model_channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(in_channels, time_embed_channels),
            nn.SiLU(),
            nn.Linear(time_embed_channels, time_embed_channels),
        )

        # Down path: (ResNet, Conformer, Downsample | None) per level. The first
        # level sees 2 * in_channels because x_t and the aligned means are
        # packed along the channel axis in forward().
        self.input_blocks = nn.ModuleList([])
        block_in_channels = in_channels * 2
        block_out_channels = model_channels
        for level in range(num_blocks):
            block = nn.ModuleList([])
            block.append(
                ResNetBlock1D(
                    in_channels=block_in_channels,
                    out_channels=block_out_channels,
                    time_embed_channels=time_embed_channels,
                )
            )
            block.append(
                self._create_transformer_block(
                    block_out_channels,
                    dim_head=transformer_dim_head,
                    num_heads=transformer_num_heads,
                    ff_mult=transformer_ff_mult,
                    conv_expansion_factor=transformer_conv_expansion_factor,
                    conv_kernel_size=transformer_conv_kernel_size,
                    dropout=transformer_dropout,
                )
            )
            if level != num_blocks - 1:
                block.append(Downsample1D(block_out_channels))
            else:
                block.append(None)
            block_in_channels = block_out_channels
            self.input_blocks.append(block)

        # Bottleneck: two (ResNet, Conformer) pairs at the lowest resolution.
        self.middle_blocks = nn.ModuleList([])
        for _ in range(2):
            block = nn.ModuleList([])
            block.append(
                ResNetBlock1D(
                    in_channels=block_out_channels,
                    out_channels=block_out_channels,
                    time_embed_channels=time_embed_channels,
                )
            )
            block.append(
                self._create_transformer_block(
                    block_out_channels,
                    dim_head=transformer_dim_head,
                    num_heads=transformer_num_heads,
                    ff_mult=transformer_ff_mult,
                    conv_expansion_factor=transformer_conv_expansion_factor,
                    conv_kernel_size=transformer_conv_kernel_size,
                    dropout=transformer_dropout,
                )
            )
            self.middle_blocks.append(block)

        # Up path: input channels are doubled by the skip connections.
        self.output_blocks = nn.ModuleList([])
        block_in_channels = block_out_channels * 2
        block_out_channels = model_channels
        for level in range(num_blocks):
            block = nn.ModuleList([])
            block.append(
                ResNetBlock1D(
                    in_channels=block_in_channels,
                    out_channels=block_out_channels,
                    time_embed_channels=time_embed_channels,
                )
            )
            block.append(
                self._create_transformer_block(
                    block_out_channels,
                    dim_head=transformer_dim_head,
                    num_heads=transformer_num_heads,
                    ff_mult=transformer_ff_mult,
                    conv_expansion_factor=transformer_conv_expansion_factor,
                    conv_kernel_size=transformer_conv_kernel_size,
                    dropout=transformer_dropout,
                )
            )
            if level != num_blocks - 1:
                block.append(Upsample1D(block_out_channels))
            else:
                block.append(None)
            block_in_channels = block_out_channels * 2
            self.output_blocks.append(block)

        self.conv_block = ConvBlock1D(model_channels, model_channels)
        self.conv = nn.Conv1d(model_channels, self.out_channels, 1)

    def _create_transformer_block(
        self,
        dim,
        dim_head: int = 64,
        num_heads: int = 4,
        ff_mult: int = 1,
        conv_expansion_factor: int = 2,
        conv_kernel_size: int = 31,
        dropout: float = 0.05,
    ):
        return ConformerBlock(
            dim=dim,
            dim_head=dim_head,
            heads=num_heads,
            ff_mult=ff_mult,
            conv_expansion_factor=conv_expansion_factor,
            conv_kernel_size=conv_kernel_size,
            attn_dropout=dropout,
            ff_dropout=dropout,
            conv_dropout=dropout,
            conv_causal=False,
        )

    def forward(self, x_t, mean, mask, t):
        t = self.time_encoder(t)
        t = self.time_embed(t)

        # Condition on the aligned encoder means by packing them with x_t
        # along the channel axis.
        x_t = pack([x_t, mean], "b * t")[0]

        hidden_states = []
        mask_states = [mask]
        for block in self.input_blocks:
            res_net_block, transformer, downsample = block
            x_t = res_net_block(x_t, mask, t)
            x_t = transformer(x_t, mask)
            hidden_states.append(x_t)
            if downsample is not None:
                x_t = downsample(x_t * mask)
                mask = mask[:, :, ::2]  # keep the mask aligned with stride-2 downsampling
                mask_states.append(mask)

        for block in self.middle_blocks:
            res_net_block, transformer = block
            mask = mask_states[-1]
            x_t = res_net_block(x_t, mask, t)
            x_t = transformer(x_t, mask)

        for block in self.output_blocks:
            res_net_block, transformer, upsample = block
            x_t = pack([x_t, hidden_states.pop()], "b * t")[0]  # U-Net skip connection
            mask = mask_states.pop()
            x_t = res_net_block(x_t, mask, t)
            x_t = transformer(x_t, mask)
            if upsample is not None:
                x_t = upsample(x_t * mask)

        output = self.conv_block(x_t, mask)
        output = self.conv(output)  # project back to out_channels
        return output * mask

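As a shape sanity check on the U-Net above (not part of the original files), a minimal sketch using the Decoder's settings from the next file; it assumes the einops and conformer packages are installed, and uses an even frame count so one stride-2 down/upsample round trip preserves the length:

import torch
from TTS.tts.layers.matcha_tts.UNet import UNet

unet = UNet(in_channels=80, model_channels=256, out_channels=80, num_blocks=2)
x_t = torch.randn(2, 80, 64)   # noisy mel sample [B, C, T]
mean = torch.randn(2, 80, 64)  # aligned encoder means [B, C, T]
mask = torch.ones(2, 1, 64)    # frame mask [B, 1, T]
t = torch.rand(2)              # one flow-matching time per batch item
v = unet(x_t, mean, mask, t)
assert v.shape == (2, 80, 64)  # predicted vector field matches the input shape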
TTS/tts/layers/matcha_tts/decoder.py

@@ -1,32 +0,0 @@
import torch
from torch import nn
import torch.nn.functional as F

from TTS.tts.layers.matcha_tts.UNet import UNet


class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.sigma_min = 1e-5
        self.predictor = UNet(
            in_channels=80,
            model_channels=256,
            out_channels=80,
            num_blocks=2,
        )

    def forward(self, x_1, mean, mask):
        """
        Shapes:
            - x_1: :math:`[B, C, T]`
            - mean: :math:`[B, C, T]`
            - mask: :math:`[B, 1, T]`
        """
        # Sample a timestep per batch item and a Gaussian prior sample x_0,
        # then form the flow-matching interpolant x_t and target field u_t.
        t = torch.rand([x_1.size(0), 1, 1], device=x_1.device, dtype=x_1.dtype)
        x_0 = torch.randn_like(x_1)
        x_t = (1 - (1 - self.sigma_min) * t) * x_0 + t * x_1
        u_t = x_1 - (1 - self.sigma_min) * x_0
        v_t = self.predictor(x_t, mean, mask, t.squeeze())
        # Summed squared error, normalized to a mean over valid frame-channel elements.
        loss = F.mse_loss(v_t, u_t, reduction="sum") / (torch.sum(mask) * u_t.shape[1])
        return loss

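For reference, the forward pass above implements the optimal-transport conditional flow-matching (OT-CFM) objective from the Matcha-TTS paper. In the code's notation, with x_0 drawn from the standard normal prior and t uniform on (0, 1):

x_t = \bigl(1 - (1 - \sigma_{\min})\,t\bigr)\,x_0 + t\,x_1
u_t = x_1 - (1 - \sigma_{\min})\,x_0
\mathcal{L}_{\mathrm{CFM}} = \frac{\lVert v_\theta(x_t, \mu, t) - u_t \rVert_2^2}{\sum \mathrm{mask} \cdot C}

Dividing the summed MSE by (number of unmasked frames) x (channel count C) makes the loss a per-element mean over valid positions only, so padding frames do not dilute the gradient.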
TTS/tts/models/matcha_tts.py

@@ -1,85 +0,0 @@
from dataclasses import field
import math

import torch

from TTS.tts.configs.matcha_tts import MatchaTTSConfig
from TTS.tts.layers.glow_tts.encoder import Encoder
from TTS.tts.layers.matcha_tts.decoder import Decoder
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import maximum_path, sequence_mask
from TTS.tts.utils.text.tokenizer import TTSTokenizer


class MatchaTTS(BaseTTS):
    def __init__(
        self,
        config: MatchaTTSConfig,
        ap: "AudioProcessor" = None,
        tokenizer: "TTSTokenizer" = None,
    ):
        super().__init__(config, ap, tokenizer)
        self.encoder = Encoder(
            self.config.num_chars,
            out_channels=80,
            hidden_channels=192,
            hidden_channels_dp=256,
            encoder_type="rel_pos_transformer",
            encoder_params={
                "kernel_size": 3,
                "dropout_p": 0.1,
                "num_layers": 6,
                "num_heads": 2,
                "hidden_channels_ffn": 768,
            },
        )
        self.decoder = Decoder()

    def forward(self, x, x_lengths, y, y_lengths):
        """
        Args:
            x (torch.Tensor):
                Input text sequence ids. :math:`[B, T_en]`
            x_lengths (torch.Tensor):
                Lengths of input text sequences. :math:`[B]`
            y (torch.Tensor):
                Target mel-spectrogram frames. :math:`[B, T_de, C_mel]`
            y_lengths (torch.Tensor):
                Lengths of target mel-spectrogram frames. :math:`[B]`
        """
        y = y.transpose(1, 2)
        y_max_length = y.size(2)
        o_mean, o_log_scale, o_log_dur, o_mask = self.encoder(x, x_lengths, g=None)
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(o_mask.dtype)
        attn_mask = torch.unsqueeze(o_mask, -1) * torch.unsqueeze(y_mask, 2)

        # Monotonic alignment search (MAS) over the frame-wise Gaussian
        # log-likelihood, as in Glow-TTS; no gradients flow through the search.
        with torch.no_grad():
            o_scale = torch.exp(-2 * o_log_scale)
            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)
            logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (y**2))
            logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), y)
            logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1)
            logp = logp1 + logp2 + logp3 + logp4
            attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()

        # Align encoded text with mel-spectrogram and get mu_y segment
        c_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(1, 2)
        _ = self.decoder(x_1=y, mean=c_mean, mask=y_mask)

    @torch.no_grad()
    def inference(self):
        pass

    @staticmethod
    def init_from_config(config: "MatchaTTSConfig"):
        pass

    def load_checkpoint(self, checkpoint_path):
        pass

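The four logp terms in forward expand the per-frame diagonal-Gaussian log-likelihood that maximum_path searches over. With \mu = o_mean, \log\sigma = o_log_scale, and o_scale = e^{-2\log\sigma} = \sigma^{-2}, summing over mel channels c:

\log \mathcal{N}(y; \mu, \sigma) = \sum_c \Bigl[ \underbrace{-\tfrac{1}{2}\log(2\pi) - \log\sigma}_{\texttt{logp1}} \; \underbrace{-\,\tfrac{1}{2}\sigma^{-2} y^2}_{\texttt{logp2}} \;+\; \underbrace{\sigma^{-2}\mu\,y}_{\texttt{logp3}} \; \underbrace{-\,\tfrac{1}{2}\mu^2\sigma^{-2}}_{\texttt{logp4}} \Bigr]

The matmuls in logp2 and logp3 evaluate this for every (text token, mel frame) pair at once, yielding the [B, T_en, T_de] score matrix that MAS aligns monotonically.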
View File

@@ -1,36 +0,0 @@
import unittest

import torch

from TTS.tts.configs.matcha_tts import MatchaTTSConfig
from TTS.tts.models.matcha_tts import MatchaTTS

torch.manual_seed(1)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

c = MatchaTTSConfig()


class TestMatchTTS(unittest.TestCase):
    @staticmethod
    def _create_inputs(batch_size=8):
        input_dummy = torch.randint(0, 24, (batch_size, 128)).long().to(device)
        input_lengths = torch.randint(100, 129, (batch_size,)).long().to(device)
        input_lengths[-1] = 128
        mel_spec = torch.rand(batch_size, 30, c.audio["num_mels"]).to(device)
        mel_lengths = torch.randint(20, 30, (batch_size,)).long().to(device)
        speaker_ids = torch.randint(0, 5, (batch_size,)).long().to(device)
        return input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids

    def _test_forward(self, batch_size):
        input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs(batch_size)
        config = MatchaTTSConfig(num_chars=32)
        model = MatchaTTS(config).to(device)
        model.train()
        model.forward(input_dummy, input_lengths, mel_spec, mel_lengths)

    def test_forward(self):
        self._test_forward(1)
        self._test_forward(3)
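The diff shows no __main__ guard for this test module; the standard one (an illustration here, not part of the original file) would let it run standalone in addition to unittest discovery:

if __name__ == "__main__":
    unittest.main()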