mirror of https://github.com/coqui-ai/TTS.git
Revert "Merge remote-tracking branch 'subuday/matcha_tts' into dev"
This reverts commitf6a23c1d8a
, reversing changes made to275229a876
.
This commit is contained in:
parent
89b5322666
commit
29a0e189b2
|
@ -1,9 +0,0 @@
|
|||
from dataclasses import dataclass, field
|
||||
|
||||
from TTS.tts.configs.shared_configs import BaseTTSConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class MatchaTTSConfig(BaseTTSConfig):
|
||||
model: str = "matcha_tts"
|
||||
num_chars: int = None
|
|
@ -1,299 +0,0 @@
|
|||
import math
|
||||
from einops import pack, rearrange
|
||||
import torch
|
||||
from torch import nn
|
||||
import conformer
|
||||
|
||||
|
||||
class PositionalEncoding(torch.nn.Module):
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
|
||||
def forward(self, x, scale=1000):
|
||||
if x.ndim < 1:
|
||||
x = x.unsqueeze(0)
|
||||
emb = math.log(10000) / (self.channels // 2 - 1)
|
||||
emb = torch.exp(torch.arange(self.channels // 2, device=x.device).float() * -emb)
|
||||
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
|
||||
class ConvBlock1D(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, num_groups=8):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
|
||||
nn.GroupNorm(num_groups, out_channels),
|
||||
nn.Mish()
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
if mask is not None:
|
||||
x = x * mask
|
||||
output = self.block(x)
|
||||
if mask is not None:
|
||||
output = output * mask
|
||||
return output
|
||||
|
||||
|
||||
class ResNetBlock1D(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, time_embed_channels, num_groups=8):
|
||||
super().__init__()
|
||||
self.block_1 = ConvBlock1D(in_channels, out_channels, num_groups=num_groups)
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Mish(),
|
||||
nn.Linear(time_embed_channels, out_channels)
|
||||
)
|
||||
self.block_2 = ConvBlock1D(in_channels=out_channels, out_channels=out_channels, num_groups=num_groups)
|
||||
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)
|
||||
|
||||
def forward(self, x, mask, t):
|
||||
h = self.block_1(x, mask)
|
||||
h += self.mlp(t).unsqueeze(-1)
|
||||
h = self.block_2(h, mask)
|
||||
output = h + self.conv(x * mask)
|
||||
return output
|
||||
|
||||
|
||||
class Downsample1D(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=3, stride=2, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Upsample1D(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.conv = nn.ConvTranspose1d(in_channels=channels, out_channels=channels, kernel_size=4, stride=2, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class ConformerBlock(conformer.ConformerBlock):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_head: int = 64,
|
||||
heads: int = 8,
|
||||
ff_mult: int = 4,
|
||||
conv_expansion_factor: int = 2,
|
||||
conv_kernel_size: int = 31,
|
||||
attn_dropout: float = 0.,
|
||||
ff_dropout: float = 0.,
|
||||
conv_dropout: float = 0.,
|
||||
conv_causal: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
dim=dim,
|
||||
dim_head=dim_head,
|
||||
heads=heads,
|
||||
ff_mult=ff_mult,
|
||||
conv_expansion_factor=conv_expansion_factor,
|
||||
conv_kernel_size=conv_kernel_size,
|
||||
attn_dropout=attn_dropout,
|
||||
ff_dropout=ff_dropout,
|
||||
conv_dropout=conv_dropout,
|
||||
conv_causal=conv_causal,
|
||||
)
|
||||
|
||||
def forward(self, x, mask,):
|
||||
x = rearrange(x, "b c t -> b t c")
|
||||
mask = rearrange(mask, "b 1 t -> b t")
|
||||
output = super().forward(x=x, mask=mask.bool())
|
||||
return rearrange(output, "b t c -> b c t")
|
||||
|
||||
|
||||
class UNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
model_channels: int,
|
||||
out_channels: int,
|
||||
num_blocks: int,
|
||||
transformer_num_heads: int = 4,
|
||||
transformer_dim_head: int = 64,
|
||||
transformer_ff_mult: int = 1,
|
||||
transformer_conv_expansion_factor: int = 2,
|
||||
transformer_conv_kernel_size: int = 31,
|
||||
transformer_dropout: float = 0.05,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.time_encoder = PositionalEncoding(in_channels)
|
||||
time_embed_channels = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
nn.Linear(in_channels, time_embed_channels),
|
||||
nn.SiLU(),
|
||||
nn.Linear(time_embed_channels, time_embed_channels),
|
||||
)
|
||||
|
||||
self.input_blocks = nn.ModuleList([])
|
||||
block_in_channels = in_channels * 2
|
||||
block_out_channels = model_channels
|
||||
for level in range(num_blocks):
|
||||
block = nn.ModuleList([])
|
||||
|
||||
block.append(
|
||||
ResNetBlock1D(
|
||||
in_channels=block_in_channels,
|
||||
out_channels=block_out_channels,
|
||||
time_embed_channels=time_embed_channels
|
||||
)
|
||||
)
|
||||
|
||||
block.append(
|
||||
self._create_transformer_block(
|
||||
block_out_channels,
|
||||
dim_head=transformer_dim_head,
|
||||
num_heads=transformer_num_heads,
|
||||
ff_mult=transformer_ff_mult,
|
||||
conv_expansion_factor=transformer_conv_expansion_factor,
|
||||
conv_kernel_size=transformer_conv_kernel_size,
|
||||
dropout=transformer_dropout,
|
||||
)
|
||||
)
|
||||
|
||||
if level != num_blocks - 1:
|
||||
block.append(Downsample1D(block_out_channels))
|
||||
else:
|
||||
block.append(None)
|
||||
|
||||
block_in_channels = block_out_channels
|
||||
self.input_blocks.append(block)
|
||||
|
||||
self.middle_blocks = nn.ModuleList([])
|
||||
for i in range(2):
|
||||
block = nn.ModuleList([])
|
||||
|
||||
block.append(
|
||||
ResNetBlock1D(
|
||||
in_channels=block_out_channels,
|
||||
out_channels=block_out_channels,
|
||||
time_embed_channels=time_embed_channels
|
||||
)
|
||||
)
|
||||
|
||||
block.append(
|
||||
self._create_transformer_block(
|
||||
block_out_channels,
|
||||
dim_head=transformer_dim_head,
|
||||
num_heads=transformer_num_heads,
|
||||
ff_mult=transformer_ff_mult,
|
||||
conv_expansion_factor=transformer_conv_expansion_factor,
|
||||
conv_kernel_size=transformer_conv_kernel_size,
|
||||
dropout=transformer_dropout,
|
||||
)
|
||||
)
|
||||
|
||||
self.middle_blocks.append(block)
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
block_in_channels = block_out_channels * 2
|
||||
block_out_channels = model_channels
|
||||
for level in range(num_blocks):
|
||||
block = nn.ModuleList([])
|
||||
|
||||
block.append(
|
||||
ResNetBlock1D(
|
||||
in_channels=block_in_channels,
|
||||
out_channels=block_out_channels,
|
||||
time_embed_channels=time_embed_channels
|
||||
)
|
||||
)
|
||||
|
||||
block.append(
|
||||
self._create_transformer_block(
|
||||
block_out_channels,
|
||||
dim_head=transformer_dim_head,
|
||||
num_heads=transformer_num_heads,
|
||||
ff_mult=transformer_ff_mult,
|
||||
conv_expansion_factor=transformer_conv_expansion_factor,
|
||||
conv_kernel_size=transformer_conv_kernel_size,
|
||||
dropout=transformer_dropout,
|
||||
)
|
||||
)
|
||||
|
||||
if level != num_blocks - 1:
|
||||
block.append(Upsample1D(block_out_channels))
|
||||
else:
|
||||
block.append(None)
|
||||
|
||||
block_in_channels = block_out_channels * 2
|
||||
self.output_blocks.append(block)
|
||||
|
||||
self.conv_block = ConvBlock1D(model_channels, model_channels)
|
||||
self.conv = nn.Conv1d(model_channels, self.out_channels, 1)
|
||||
|
||||
def _create_transformer_block(
|
||||
self,
|
||||
dim,
|
||||
dim_head: int = 64,
|
||||
num_heads: int = 4,
|
||||
ff_mult: int = 1,
|
||||
conv_expansion_factor: int = 2,
|
||||
conv_kernel_size: int = 31,
|
||||
dropout: float = 0.05,
|
||||
):
|
||||
return ConformerBlock(
|
||||
dim=dim,
|
||||
dim_head=dim_head,
|
||||
heads=num_heads,
|
||||
ff_mult=ff_mult,
|
||||
conv_expansion_factor=conv_expansion_factor,
|
||||
conv_kernel_size=conv_kernel_size,
|
||||
attn_dropout=dropout,
|
||||
ff_dropout=dropout,
|
||||
conv_dropout=dropout,
|
||||
conv_causal=False,
|
||||
)
|
||||
|
||||
def forward(self, x_t, mean, mask, t):
|
||||
t = self.time_encoder(t)
|
||||
t = self.time_embed(t)
|
||||
|
||||
x_t = pack([x_t, mean], "b * t")[0]
|
||||
|
||||
hidden_states = []
|
||||
mask_states = [mask]
|
||||
|
||||
for block in self.input_blocks:
|
||||
res_net_block, transformer, downsample = block
|
||||
|
||||
x_t = res_net_block(x_t, mask, t)
|
||||
x_t = transformer(x_t, mask)
|
||||
|
||||
hidden_states.append(x_t)
|
||||
|
||||
if downsample is not None:
|
||||
x_t = downsample(x_t * mask)
|
||||
mask = mask[:, :, ::2]
|
||||
mask_states.append(mask)
|
||||
|
||||
for block in self.middle_blocks:
|
||||
res_net_block, transformer = block
|
||||
mask = mask_states[-1]
|
||||
x_t = res_net_block(x_t, mask, t)
|
||||
x_t = transformer(x_t, mask)
|
||||
|
||||
for block in self.output_blocks:
|
||||
res_net_block, transformer, upsample = block
|
||||
|
||||
x_t = pack([x_t, hidden_states.pop()], "b * t")[0]
|
||||
mask = mask_states.pop()
|
||||
x_t = res_net_block(x_t, mask, t)
|
||||
x_t = transformer(x_t, mask)
|
||||
|
||||
if upsample is not None:
|
||||
x_t = upsample(x_t * mask)
|
||||
|
||||
output = self.conv_block(x_t)
|
||||
output = self.conv(x_t)
|
||||
|
||||
return output * mask
|
|
@ -1,32 +0,0 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from TTS.tts.layers.matcha_tts.UNet import UNet
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sigma_min = 1e-5
|
||||
self.predictor = UNet(
|
||||
in_channels=80,
|
||||
model_channels=256,
|
||||
out_channels=80,
|
||||
num_blocks=2
|
||||
)
|
||||
|
||||
def forward(self, x_1, mean, mask):
|
||||
"""
|
||||
Shapes:
|
||||
- x_1: :math:`[B, C, T]`
|
||||
- mean: :math:`[B, C ,T]`
|
||||
- mask: :math:`[B, 1, T]`
|
||||
"""
|
||||
t = torch.rand([x_1.size(0), 1, 1], device=x_1.device, dtype=x_1.dtype)
|
||||
x_0 = torch.randn_like(x_1)
|
||||
x_t = (1 - (1 - self.sigma_min) * t) * x_0 + t * x_1
|
||||
u_t = x_1 - (1 - self.sigma_min) * x_0
|
||||
v_t = self.predictor(x_t, mean, mask, t.squeeze())
|
||||
loss = F.mse_loss(v_t, u_t, reduction="sum") / (torch.sum(mask) * u_t.shape[1])
|
||||
return loss
|
|
@ -1,85 +0,0 @@
|
|||
from dataclasses import field
|
||||
import math
|
||||
import torch
|
||||
|
||||
from TTS.tts.configs.matcha_tts import MatchaTTSConfig
|
||||
from TTS.tts.layers.glow_tts.encoder import Encoder
|
||||
from TTS.tts.layers.matcha_tts.decoder import Decoder
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.helpers import maximum_path, sequence_mask
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
|
||||
|
||||
class MatchaTTS(BaseTTS):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MatchaTTSConfig,
|
||||
ap: "AudioProcessor" = None,
|
||||
tokenizer: "TTSTokenizer" = None,
|
||||
):
|
||||
super().__init__(config, ap, tokenizer)
|
||||
self.encoder = Encoder(
|
||||
self.config.num_chars,
|
||||
out_channels=80,
|
||||
hidden_channels=192,
|
||||
hidden_channels_dp=256,
|
||||
encoder_type='rel_pos_transformer',
|
||||
encoder_params={
|
||||
"kernel_size": 3,
|
||||
"dropout_p": 0.1,
|
||||
"num_layers": 6,
|
||||
"num_heads": 2,
|
||||
"hidden_channels_ffn": 768,
|
||||
}
|
||||
)
|
||||
|
||||
self.decoder = Decoder()
|
||||
|
||||
def forward(self, x, x_lengths, y, y_lengths):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor):
|
||||
Input text sequence ids. :math:`[B, T_en]`
|
||||
|
||||
x_lengths (torch.Tensor):
|
||||
Lengths of input text sequences. :math:`[B]`
|
||||
|
||||
y (torch.Tensor):
|
||||
Target mel-spectrogram frames. :math:`[B, T_de, C_mel]`
|
||||
|
||||
y_lengths (torch.Tensor):
|
||||
Lengths of target mel-spectrogram frames. :math:`[B]`
|
||||
"""
|
||||
y = y.transpose(1, 2)
|
||||
y_max_length = y.size(2)
|
||||
|
||||
o_mean, o_log_scale, o_log_dur, o_mask = self.encoder(x, x_lengths, g=None)
|
||||
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(o_mask.dtype)
|
||||
attn_mask = torch.unsqueeze(o_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||
|
||||
with torch.no_grad():
|
||||
o_scale = torch.exp(-2 * o_log_scale)
|
||||
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)
|
||||
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (y**2))
|
||||
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), y)
|
||||
logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1)
|
||||
logp = logp1 + logp2 + logp3 + logp4
|
||||
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
|
||||
|
||||
# Align encoded text with mel-spectrogram and get mu_y segment
|
||||
c_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
_ = self.decoder(x_1=y, mean=c_mean, mask=y_mask)
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "MatchaTTSConfig"):
|
||||
pass
|
||||
|
||||
def load_checkpoint(self, checkpoint_path):
|
||||
pass
|
|
@ -1,36 +0,0 @@
|
|||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from TTS.tts.configs.matcha_tts import MatchaTTSConfig
|
||||
from TTS.tts.models.matcha_tts import MatchaTTS
|
||||
|
||||
torch.manual_seed(1)
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
c = MatchaTTSConfig()
|
||||
|
||||
|
||||
class TestMatchTTS(unittest.TestCase):
|
||||
@staticmethod
|
||||
def _create_inputs(batch_size=8):
|
||||
input_dummy = torch.randint(0, 24, (batch_size, 128)).long().to(device)
|
||||
input_lengths = torch.randint(100, 129, (batch_size,)).long().to(device)
|
||||
input_lengths[-1] = 128
|
||||
mel_spec = torch.rand(batch_size, 30, c.audio["num_mels"]).to(device)
|
||||
mel_lengths = torch.randint(20, 30, (batch_size,)).long().to(device)
|
||||
speaker_ids = torch.randint(0, 5, (batch_size,)).long().to(device)
|
||||
return input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids
|
||||
|
||||
def _test_forward(self, batch_size):
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs(batch_size)
|
||||
config = MatchaTTSConfig(num_chars=32)
|
||||
model = MatchaTTS(config).to(device)
|
||||
|
||||
model.train()
|
||||
|
||||
model.forward(input_dummy, input_lengths, mel_spec, mel_lengths)
|
||||
|
||||
def test_forward(self):
|
||||
self._test_forward(1)
|
||||
self._test_forward(3)
|
Loading…
Reference in New Issue