Implement VITS model 🚀

VITS model implementation built on Glow TTS and HiFiGAN
layers.
This commit is contained in:
Eren Gölge 2021-08-09 08:00:43 +00:00
parent 060e746e21
commit c312acac7d
20 changed files with 2055 additions and 73 deletions

View File

@ -0,0 +1,60 @@
from dataclasses import dataclass, field
from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.vits import VitsArgs
@dataclass
class VitsConfig(BaseTTSConfig):
"""Defines parameters for VITS End2End TTS model.
Example:
>>> from TTS.tts.configs import VitsConfig
>>> config = VitsConfig()
"""
model: str = "vits"
# model specific params
model_args: VitsArgs = field(default_factory=VitsArgs)
# optimizer
grad_clip: float = field(default_factory=lambda: [5, 5])
lr_gen: float = 0.0002
lr_disc: float = 0.0002
lr_scheduler_gen: str = "ExponentialLR"
lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
lr_scheduler_disc: str = "ExponentialLR"
lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
scheduler_after_epoch: bool = True
optimizer: str = "AdamW"
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01})
# loss params
kl_loss_alpha: float = 1.0
disc_loss_alpha: float = 1.0
gen_loss_alpha: float = 1.0
feat_loss_alpha: float = 1.0
mel_loss_alpha: float = 45.0
# data loader params
return_wav: bool = True
compute_linear_spec: bool = True
# overrides
min_seq_len: int = 13
max_seq_len: int = 200
r: int = 1 # DO NOT CHANGE
add_blank: bool = True
# testing
test_sentences: List[str] = field(
default_factory=lambda: [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963.",
]
)

View File

@ -191,6 +191,7 @@ class TTSDataset(Dataset):
else: else:
text, wav_file, speaker_name = item text, wav_file, speaker_name = item
attn = None attn = None
raw_text = text
wav = np.asarray(self.load_wav(wav_file), dtype=np.float32) wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
@ -236,6 +237,7 @@ class TTSDataset(Dataset):
return self.load_data(self.rescue_item_idx) return self.load_data(self.rescue_item_idx)
sample = { sample = {
"raw_text": raw_text,
"text": text, "text": text,
"wav": wav, "wav": wav,
"attn": attn, "attn": attn,
@ -360,6 +362,7 @@ class TTSDataset(Dataset):
wav = [batch[idx]["wav"] for idx in ids_sorted_decreasing] wav = [batch[idx]["wav"] for idx in ids_sorted_decreasing]
item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing] item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing]
text = [batch[idx]["text"] for idx in ids_sorted_decreasing] text = [batch[idx]["text"] for idx in ids_sorted_decreasing]
raw_text = [batch[idx]["raw_text"] for idx in ids_sorted_decreasing]
speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing] speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
# get pre-computed d-vectors # get pre-computed d-vectors
@ -450,6 +453,7 @@ class TTSDataset(Dataset):
attns = torch.FloatTensor(attns).unsqueeze(1) attns = torch.FloatTensor(attns).unsqueeze(1)
else: else:
attns = None attns = None
# TODO: return dictionary
return ( return (
text, text,
text_lenghts, text_lenghts,
@ -463,6 +467,7 @@ class TTSDataset(Dataset):
speaker_ids, speaker_ids,
attns, attns,
wav_padded, wav_padded,
raw_text,
) )
raise TypeError( raise TypeError(

View File

@ -28,6 +28,31 @@ class LayerNorm(nn.Module):
return x return x
class LayerNorm2(nn.Module):
"""Layer norm for the 2nd dimension of the input using torch primitive.
Args:
channels (int): number of channels (2nd dimension) of the input.
eps (float): to prevent 0 division
Shapes:
- input: (B, C, T)
- output: (B, C, T)
"""
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = torch.nn.functional.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
class TemporalBatchNorm1d(nn.BatchNorm1d): class TemporalBatchNorm1d(nn.BatchNorm1d):
"""Normalize each channel separately over time and batch.""" """Normalize each channel separately over time and batch."""

View File

@ -18,7 +18,7 @@ class DurationPredictor(nn.Module):
dropout_p (float): Dropout rate used after each conv layer. dropout_p (float): Dropout rate used after each conv layer.
""" """
def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p): def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None):
super().__init__() super().__init__()
# class arguments # class arguments
self.in_channels = in_channels self.in_channels = in_channels
@ -33,13 +33,18 @@ class DurationPredictor(nn.Module):
self.norm_2 = LayerNorm(hidden_channels) self.norm_2 = LayerNorm(hidden_channels)
# output layer # output layer
self.proj = nn.Conv1d(hidden_channels, 1, 1) self.proj = nn.Conv1d(hidden_channels, 1, 1)
if cond_channels is not None and cond_channels != 0:
self.cond = nn.Conv1d(cond_channels, in_channels, 1)
def forward(self, x, x_mask): def forward(self, x, x_mask, g=None):
""" """
Shapes: Shapes:
- x: :math:`[B, C, T]` - x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]` - x_mask: :math:`[B, 1, T]`
- g: :math:`[B, C, 1]`
""" """
if g is not None:
x = x + self.cond(g)
x = self.conv_1(x * x_mask) x = self.conv_1(x * x_mask)
x = torch.relu(x) x = torch.relu(x)
x = self.norm_1(x) x = self.norm_1(x)

View File

@ -16,7 +16,7 @@ class ResidualConv1dLayerNormBlock(nn.Module):
:: ::
x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o
|---------------> conv1d_1x1 -----------------------| |---------------> conv1d_1x1 ------------------|
Args: Args:
in_channels (int): number of input tensor channels. in_channels (int): number of input tensor channels.

View File

@ -4,7 +4,7 @@ import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from TTS.tts.layers.glow_tts.glow import LayerNorm from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2
class RelativePositionMultiHeadAttention(nn.Module): class RelativePositionMultiHeadAttention(nn.Module):
@ -271,7 +271,7 @@ class FeedForwardNetwork(nn.Module):
dropout_p (float, optional): dropout rate. Defaults to 0. dropout_p (float, optional): dropout rate. Defaults to 0.
""" """
def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dropout_p=0.0): def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dropout_p=0.0, causal=False):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
@ -280,17 +280,46 @@ class FeedForwardNetwork(nn.Module):
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.dropout_p = dropout_p self.dropout_p = dropout_p
self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2) if causal:
self.conv_2 = nn.Conv1d(hidden_channels, out_channels, kernel_size, padding=kernel_size // 2) self.padding = self._causal_padding
else:
self.padding = self._same_padding
self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size)
self.conv_2 = nn.Conv1d(hidden_channels, out_channels, kernel_size)
self.dropout = nn.Dropout(dropout_p) self.dropout = nn.Dropout(dropout_p)
def forward(self, x, x_mask): def forward(self, x, x_mask):
x = self.conv_1(x * x_mask) x = self.conv_1(self.padding(x * x_mask))
x = torch.relu(x) x = torch.relu(x)
x = self.dropout(x) x = self.dropout(x)
x = self.conv_2(x * x_mask) x = self.conv_2(self.padding(x * x_mask))
return x * x_mask return x * x_mask
def _causal_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = self.kernel_size - 1
pad_r = 0
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, self._pad_shape(padding))
return x
def _same_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = (self.kernel_size - 1) // 2
pad_r = self.kernel_size // 2
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, self._pad_shape(padding))
return x
@staticmethod
def _pad_shape(padding):
l = padding[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
class RelativePositionTransformer(nn.Module): class RelativePositionTransformer(nn.Module):
"""Transformer with Relative Potional Encoding. """Transformer with Relative Potional Encoding.
@ -310,20 +339,23 @@ class RelativePositionTransformer(nn.Module):
If default, relative encoding is disabled and it is a regular transformer. If default, relative encoding is disabled and it is a regular transformer.
Defaults to None. Defaults to None.
input_length (int, optional): input lenght to limit position encoding. Defaults to None. input_length (int, optional): input lenght to limit position encoding. Defaults to None.
layer_norm_type (str, optional): type "1" uses torch tensor operations and type "2" uses torch layer_norm
primitive. Use type "2", type "1: is for backward compat. Defaults to "1".
""" """
def __init__( def __init__(
self, self,
in_channels, in_channels: int,
out_channels, out_channels: int,
hidden_channels, hidden_channels: int,
hidden_channels_ffn, hidden_channels_ffn: int,
num_heads, num_heads: int,
num_layers, num_layers: int,
kernel_size=1, kernel_size=1,
dropout_p=0.0, dropout_p=0.0,
rel_attn_window_size=None, rel_attn_window_size: int = None,
input_length=None, input_length: int = None,
layer_norm_type: str = "1",
): ):
super().__init__() super().__init__()
self.hidden_channels = hidden_channels self.hidden_channels = hidden_channels
@ -351,7 +383,12 @@ class RelativePositionTransformer(nn.Module):
input_length=input_length, input_length=input_length,
) )
) )
self.norm_layers_1.append(LayerNorm(hidden_channels)) if layer_norm_type == "1":
self.norm_layers_1.append(LayerNorm(hidden_channels))
elif layer_norm_type == "2":
self.norm_layers_1.append(LayerNorm2(hidden_channels))
else:
raise ValueError(" [!] Unknown layer norm type")
if hidden_channels != out_channels and (idx + 1) == self.num_layers: if hidden_channels != out_channels and (idx + 1) == self.num_layers:
self.proj = nn.Conv1d(hidden_channels, out_channels, 1) self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
@ -366,7 +403,12 @@ class RelativePositionTransformer(nn.Module):
) )
) )
self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels)) if layer_norm_type == "1":
self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels))
elif layer_norm_type == "2":
self.norm_layers_2.append(LayerNorm2(hidden_channels if (idx + 1) != self.num_layers else out_channels))
else:
raise ValueError(" [!] Unknown layer norm type")
def forward(self, x, x_mask): def forward(self, x, x_mask):
""" """

View File

@ -2,11 +2,13 @@ import math
import numpy as np import numpy as np
import torch import torch
from coqpit import Coqpit
from torch import nn from torch import nn
from torch.nn import functional from torch.nn import functional
from TTS.tts.utils.data import sequence_mask from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.ssim import ssim from TTS.tts.utils.ssim import ssim
from TTS.utils.audio import TorchSTFT
# pylint: disable=abstract-method # pylint: disable=abstract-method
@ -514,3 +516,142 @@ class AlignTTSLoss(nn.Module):
+ self.mdn_alpha * mdn_loss + self.mdn_alpha * mdn_loss
) )
return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss} return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss}
class VitsGeneratorLoss(nn.Module):
def __init__(self, c: Coqpit):
super().__init__()
self.kl_loss_alpha = c.kl_loss_alpha
self.gen_loss_alpha = c.gen_loss_alpha
self.feat_loss_alpha = c.feat_loss_alpha
self.mel_loss_alpha = c.mel_loss_alpha
self.stft = TorchSTFT(
c.audio.fft_size,
c.audio.hop_length,
c.audio.win_length,
sample_rate=c.audio.sample_rate,
mel_fmin=c.audio.mel_fmin,
mel_fmax=c.audio.mel_fmax,
n_mels=c.audio.num_mels,
use_mel=True,
do_amp_to_db=True,
)
@staticmethod
def feature_loss(feats_real, feats_generated):
loss = 0
for dr, dg in zip(feats_real, feats_generated):
for rl, gl in zip(dr, dg):
rl = rl.float().detach()
gl = gl.float()
loss += torch.mean(torch.abs(rl - gl))
return loss * 2
@staticmethod
def generator_loss(scores_fake):
loss = 0
gen_losses = []
for dg in scores_fake:
dg = dg.float()
l = torch.mean((1 - dg) ** 2)
gen_losses.append(l)
loss += l
return loss, gen_losses
@staticmethod
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
"""
z_p, logs_q: [b, h, t_t]
m_p, logs_p: [b, h, t_t]
"""
z_p = z_p.float()
logs_q = logs_q.float()
m_p = m_p.float()
logs_p = logs_p.float()
z_mask = z_mask.float()
kl = logs_p - logs_q - 0.5
kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
kl = torch.sum(kl * z_mask)
l = kl / torch.sum(z_mask)
return l
def forward(
self,
waveform,
waveform_hat,
z_p,
logs_q,
m_p,
logs_p,
z_len,
scores_disc_fake,
feats_disc_fake,
feats_disc_real,
):
"""
Shapes:
- wavefrom: :math:`[B, 1, T]`
- waveform_hat: :math:`[B, 1, T]`
- z_p: :math:`[B, C, T]`
- logs_q: :math:`[B, C, T]`
- m_p: :math:`[B, C, T]`
- logs_p: :math:`[B, C, T]`
- z_len: :math:`[B]`
- scores_disc_fake[i]: :math:`[B, C]`
- feats_disc_fake[i][j]: :math:`[B, C, T', P]`
- feats_disc_real[i][j]: :math:`[B, C, T', P]`
"""
loss = 0.0
return_dict = {}
z_mask = sequence_mask(z_len).float()
# compute mel spectrograms from the waveforms
mel = self.stft(waveform)
mel_hat = self.stft(waveform_hat)
# compute losses
loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha
loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
loss = loss_kl + loss_feat + loss_mel + loss_gen
# pass losses to the dict
return_dict["loss_gen"] = loss_gen
return_dict["loss_kl"] = loss_kl
return_dict["loss_feat"] = loss_feat
return_dict["loss_mel"] = loss_mel
return_dict["loss"] = loss
return return_dict
class VitsDiscriminatorLoss(nn.Module):
def __init__(self, c: Coqpit):
super().__init__()
self.disc_loss_alpha = c.disc_loss_alpha
@staticmethod
def discriminator_loss(scores_real, scores_fake):
loss = 0
real_losses = []
fake_losses = []
for dr, dg in zip(scores_real, scores_fake):
dr = dr.float()
dg = dg.float()
real_loss = torch.mean((1 - dr) ** 2)
fake_loss = torch.mean(dg ** 2)
loss += real_loss + fake_loss
real_losses.append(real_loss.item())
fake_losses.append(fake_loss.item())
return loss, real_losses, fake_losses
def forward(self, scores_disc_real, scores_disc_fake):
loss = 0.0
return_dict = {}
loss_disc, _, _ = self.discriminator_loss(scores_disc_real, scores_disc_fake)
return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
loss = loss + loss_disc
return_dict["loss_disc"] = loss_disc
return_dict["loss"] = loss
return return_dict

View File

@ -0,0 +1,77 @@
import torch
from torch import nn
from torch.nn.modules.conv import Conv1d
from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator
class DiscriminatorS(torch.nn.Module):
"""HiFiGAN Scale Discriminator. Channel sizes are different from the original HiFiGAN.
Args:
use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm.
"""
def __init__(self, use_spectral_norm=False):
super().__init__()
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
]
)
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x):
"""
Args:
x (Tensor): input waveform.
Returns:
Tensor: discriminator scores.
List[Tensor]: list of features from the convolutiona layers.
"""
feat = []
for l in self.convs:
x = l(x)
x = torch.nn.functional.leaky_relu(x, 0.1)
feat.append(x)
x = self.conv_post(x)
feat.append(x)
x = torch.flatten(x, 1, -1)
return x, feat
class VitsDiscriminator(nn.Module):
"""VITS discriminator wrapping one Scale Discriminator and a stack of Period Discriminator.
::
waveform -> ScaleDiscriminator() -> scores_sd, feats_sd --> append() -> scores, feats
|--> MultiPeriodDiscriminator() -> scores_mpd, feats_mpd ^
Args:
use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm.
"""
def __init__(self, use_spectral_norm=False):
super().__init__()
self.sd = DiscriminatorS(use_spectral_norm=use_spectral_norm)
self.mpd = MultiPeriodDiscriminator(use_spectral_norm=use_spectral_norm)
def forward(self, x):
"""
Args:
x (Tensor): input waveform.
Returns:
List[Tensor]: discriminator scores.
List[List[Tensor]]: list of list of features from each layers of each discriminator.
"""
scores, feats = self.mpd(x)
score_sd, feats_sd = self.sd(x)
return scores + [score_sd], feats + [feats_sd]

View File

@ -0,0 +1,271 @@
import math
import torch
from torch import nn
from TTS.tts.layers.glow_tts.glow import WN
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.utils.data import sequence_mask
LRELU_SLOPE = 0.1
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
class TextEncoder(nn.Module):
def __init__(
self,
n_vocab: int,
out_channels: int,
hidden_channels: int,
hidden_channels_ffn: int,
num_heads: int,
num_layers: int,
kernel_size: int,
dropout_p: float,
):
"""Text Encoder for VITS model.
Args:
n_vocab (int): Number of characters for the embedding layer.
out_channels (int): Number of channels for the output.
hidden_channels (int): Number of channels for the hidden layers.
hidden_channels_ffn (int): Number of channels for the convolutional layers.
num_heads (int): Number of attention heads for the Transformer layers.
num_layers (int): Number of Transformer layers.
kernel_size (int): Kernel size for the FFN layers in Transformer network.
dropout_p (float): Dropout rate for the Transformer layers.
"""
super().__init__()
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.emb = nn.Embedding(n_vocab, hidden_channels)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
self.encoder = RelativePositionTransformer(
in_channels=hidden_channels,
out_channels=hidden_channels,
hidden_channels=hidden_channels,
hidden_channels_ffn=hidden_channels_ffn,
num_heads=num_heads,
num_layers=num_layers,
kernel_size=kernel_size,
dropout_p=dropout_p,
layer_norm_type="2",
rel_attn_window_size=4,
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths):
"""
Shapes:
- x: :math:`[B, T]`
- x_length: :math:`[B]`
"""
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
x = torch.transpose(x, 1, -1) # [b, h, t]
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.encoder(x * x_mask, x_mask)
stats = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
return x, m, logs, x_mask
class ResidualCouplingBlock(nn.Module):
def __init__(
self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
num_layers,
dropout_p=0,
cond_channels=0,
mean_only=False,
):
assert channels % 2 == 0, "channels should be divisible by 2"
super().__init__()
self.half_channels = channels // 2
self.mean_only = mean_only
# input layer
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
# coupling layers
self.enc = WN(
hidden_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_layers,
dropout_p=dropout_p,
c_in_channels=cond_channels,
)
# output layer
# Initializing last layer to 0 makes the affine coupling layers
# do nothing at first. This helps with training stability
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
self.post.weight.data.zero_()
self.post.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
- g: :math:`[B, C, 1]`
"""
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0) * x_mask
h = self.enc(h, x_mask, g=g)
stats = self.post(h) * x_mask
if not self.mean_only:
m, log_scale = torch.split(stats, [self.half_channels] * 2, 1)
else:
m = stats
log_scale = torch.zeros_like(m)
if not reverse:
x1 = m + x1 * torch.exp(log_scale) * x_mask
x = torch.cat([x0, x1], 1)
logdet = torch.sum(log_scale, [1, 2])
return x, logdet
else:
x1 = (x1 - m) * torch.exp(-log_scale) * x_mask
x = torch.cat([x0, x1], 1)
return x
class ResidualCouplingBlocks(nn.Module):
def __init__(
self,
channels: int,
hidden_channels: int,
kernel_size: int,
dilation_rate: int,
num_layers: int,
num_flows=4,
cond_channels=0,
):
"""Redisual Coupling blocks for VITS flow layers.
Args:
channels (int): Number of input and output tensor channels.
hidden_channels (int): Number of hidden network channels.
kernel_size (int): Kernel size of the WaveNet layers.
dilation_rate (int): Dilation rate of the WaveNet layers.
num_layers (int): Number of the WaveNet layers.
num_flows (int, optional): Number of Residual Coupling blocks. Defaults to 4.
cond_channels (int, optional): Number of channels of the conditioning tensor. Defaults to 0.
"""
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.num_layers = num_layers
self.num_flows = num_flows
self.cond_channels = cond_channels
self.flows = nn.ModuleList()
for _ in range(num_flows):
self.flows.append(
ResidualCouplingBlock(
channels,
hidden_channels,
kernel_size,
dilation_rate,
num_layers,
cond_channels=cond_channels,
mean_only=True,
)
)
def forward(self, x, x_mask, g=None, reverse=False):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
- g: :math:`[B, C, 1]`
"""
if not reverse:
for flow in self.flows:
x, _ = flow(x, x_mask, g=g, reverse=reverse)
x = torch.flip(x, [1])
else:
for flow in reversed(self.flows):
x = torch.flip(x, [1])
x = flow(x, x_mask, g=g, reverse=reverse)
return x
class PosteriorEncoder(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
hidden_channels: int,
kernel_size: int,
dilation_rate: int,
num_layers: int,
cond_channels=0,
):
"""Posterior Encoder of VITS model.
::
x -> conv1x1() -> WaveNet() (non-causal) -> conv1x1() -> split() -> [m, s] -> sample(m, s) -> z
Args:
in_channels (int): Number of input tensor channels.
out_channels (int): Number of output tensor channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size of the WaveNet convolution layers.
dilation_rate (int): Dilation rate of the WaveNet layers.
num_layers (int): Number of the WaveNet layers.
cond_channels (int, optional): Number of conditioning tensor channels. Defaults to 0.
"""
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.num_layers = num_layers
self.cond_channels = cond_channels
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = WN(
hidden_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels=cond_channels
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths, g=None):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_lengths: :math:`[B, 1]`
- g: :math:`[B, C, 1]`
"""
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
mean, log_scale = torch.split(stats, self.out_channels, dim=1)
z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask
return z, mean, log_scale, x_mask

View File

@ -0,0 +1,276 @@
import math
import torch
from torch import nn
from torch.nn import functional as F
from TTS.tts.layers.generic.normalization import LayerNorm2
from TTS.tts.layers.vits.transforms import piecewise_rational_quadratic_transform
class DilatedDepthSeparableConv(nn.Module):
def __init__(self, channels, kernel_size, num_layers, dropout_p=0.0) -> torch.tensor:
"""Dilated Depth-wise Separable Convolution module.
::
x |-> DDSConv(x) -> LayerNorm(x) -> GeLU(x) -> Conv1x1(x) -> LayerNorm(x) -> GeLU(x) -> + -> o
|-------------------------------------------------------------------------------------^
Args:
channels ([type]): [description]
kernel_size ([type]): [description]
num_layers ([type]): [description]
dropout_p (float, optional): [description]. Defaults to 0.0.
Returns:
torch.tensor: Network output masked by the input sequence mask.
"""
super().__init__()
self.num_layers = num_layers
self.convs_sep = nn.ModuleList()
self.convs_1x1 = nn.ModuleList()
self.norms_1 = nn.ModuleList()
self.norms_2 = nn.ModuleList()
for i in range(num_layers):
dilation = kernel_size ** i
padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append(
nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
)
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
self.norms_1.append(LayerNorm2(channels))
self.norms_2.append(LayerNorm2(channels))
self.dropout = nn.Dropout(dropout_p)
def forward(self, x, x_mask, g=None):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
"""
if g is not None:
x = x + g
for i in range(self.num_layers):
y = self.convs_sep[i](x * x_mask)
y = self.norms_1[i](y)
y = F.gelu(y)
y = self.convs_1x1[i](y)
y = self.norms_2[i](y)
y = F.gelu(y)
y = self.dropout(y)
x = x + y
return x * x_mask
class ElementwiseAffine(nn.Module):
"""Element-wise affine transform like no-population stats BatchNorm alternative.
Args:
channels (int): Number of input tensor channels.
"""
def __init__(self, channels):
super().__init__()
self.translation = nn.Parameter(torch.zeros(channels, 1))
self.log_scale = nn.Parameter(torch.zeros(channels, 1))
def forward(self, x, x_mask, reverse=False, **kwargs): # pylint: disable=unused-argument
if not reverse:
y = (x * torch.exp(self.log_scale) + self.translation) * x_mask
logdet = torch.sum(self.log_scale * x_mask, [1, 2])
return y, logdet
x = (x - self.translation) * torch.exp(-self.log_scale) * x_mask
return x
class ConvFlow(nn.Module):
"""Dilated depth separable convolutional based spline flow.
Args:
in_channels (int): Number of input tensor channels.
hidden_channels (int): Number of in network channels.
kernel_size (int): Convolutional kernel size.
num_layers (int): Number of convolutional layers.
num_bins (int, optional): Number of spline bins. Defaults to 10.
tail_bound (float, optional): Tail bound for PRQT. Defaults to 5.0.
"""
def __init__(
self,
in_channels: int,
hidden_channels: int,
kernel_size: int,
num_layers: int,
num_bins=10,
tail_bound=5.0,
):
super().__init__()
self.num_bins = num_bins
self.tail_bound = tail_bound
self.hidden_channels = hidden_channels
self.half_channels = in_channels // 2
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers, dropout_p=0.0)
self.proj = nn.Conv1d(hidden_channels, self.half_channels * (num_bins * 3 - 1), 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0)
h = self.convs(h, x_mask, g=g)
h = self.proj(h) * x_mask
b, c, t = x0.shape
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.hidden_channels)
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.hidden_channels)
unnormalized_derivatives = h[..., 2 * self.num_bins :]
x1, logabsdet = piecewise_rational_quadratic_transform(
x1,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=reverse,
tails="linear",
tail_bound=self.tail_bound,
)
x = torch.cat([x0, x1], 1) * x_mask
logdet = torch.sum(logabsdet * x_mask, [1, 2])
if not reverse:
return x, logdet
return x
class StochasticDurationPredictor(nn.Module):
"""Stochastic duration predictor with Spline Flows.
It applies Variational Dequantization and Variationsl Data Augmentation.
Paper:
SDP: https://arxiv.org/pdf/2106.06103.pdf
Spline Flow: https://arxiv.org/abs/1906.04032
::
## Inference
x -> TextCondEncoder() -> Flow() -> dr_hat
noise ----------------------^
## Training
|---------------------|
x -> TextCondEncoder() -> + -> PosteriorEncoder() -> split() -> z_u, z_v -> (d - z_u) -> concat() -> Flow() -> noise
d -> DurCondEncoder() -> ^ |
|------------------------------------------------------------------------------|
Args:
in_channels (int): Number of input tensor channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size of convolutional layers.
dropout_p (float): Dropout rate.
num_flows (int, optional): Number of flow blocks. Defaults to 4.
cond_channels (int, optional): Number of channels of conditioning tensor. Defaults to 0.
"""
def __init__(
self, in_channels: int, hidden_channels: int, kernel_size: int, dropout_p: float, num_flows=4, cond_channels=0
):
super().__init__()
# condition encoder text
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p)
self.proj = nn.Conv1d(hidden_channels, hidden_channels, 1)
# posterior encoder
self.flows = nn.ModuleList()
self.flows.append(ElementwiseAffine(2))
self.flows += [ConvFlow(2, hidden_channels, kernel_size, num_layers=3) for _ in range(num_flows)]
# condition encoder duration
self.post_pre = nn.Conv1d(1, hidden_channels, 1)
self.post_convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p)
self.post_proj = nn.Conv1d(hidden_channels, hidden_channels, 1)
# flow layers
self.post_flows = nn.ModuleList()
self.post_flows.append(ElementwiseAffine(2))
self.post_flows += [ConvFlow(2, hidden_channels, kernel_size, num_layers=3) for _ in range(num_flows)]
if cond_channels != 0 and cond_channels is not None:
self.cond = nn.Conv1d(cond_channels, hidden_channels, 1)
def forward(self, x, x_mask, dr=None, g=None, reverse=False, noise_scale=1.0):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
- dr: :math:`[B, 1, T]`
- g: :math:`[B, C]`
"""
# condition encoder text
x = self.pre(x)
if g is not None:
x = x + self.cond(g)
x = self.convs(x, x_mask)
x = self.proj(x) * x_mask
if not reverse:
flows = self.flows
assert dr is not None
# condition encoder duration
h = self.post_pre(dr)
h = self.post_convs(h, x_mask)
h = self.post_proj(h) * x_mask
noise = torch.rand(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
z_q = noise
# posterior encoder
logdet_tot_q = 0.0
for idx, flow in enumerate(self.post_flows):
z_q, logdet_q = flow(z_q, x_mask, g=(x + h))
logdet_tot_q = logdet_tot_q + logdet_q
if idx > 0:
z_q = torch.flip(z_q, [1])
z_u, z_v = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask
z0 = (dr - u) * x_mask
# posterior encoder - neg log likelihood
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
nll_posterior_encoder = (
torch.sum(-0.5 * (math.log(2 * math.pi) + (noise ** 2)) * x_mask, [1, 2]) - logdet_tot_q
)
z0 = torch.log(torch.clamp_min(z0, 1e-5)) * x_mask
logdet_tot = torch.sum(-z0, [1, 2])
z = torch.cat([z0, z_v], 1)
# flow layers
for idx, flow in enumerate(flows):
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet
if idx > 0:
z = torch.flip(z, [1])
# flow layers - neg log likelihood
nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot
return nll_flow_layers + nll_posterior_encoder
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = torch.rand(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
for flow in flows:
z = torch.flip(z, [1])
z = flow(z, x_mask, g=x, reverse=reverse)
z0, _ = torch.split(z, [1, 1], 1)
logw = z0
return logw

View File

@ -0,0 +1,203 @@
# adopted from https://github.com/bayesiains/nflows
import numpy as np
import torch
from torch.nn import functional as F
DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3
def piecewise_rational_quadratic_transform(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails=None,
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if tails is None:
spline_fn = rational_quadratic_spline
spline_kwargs = {}
else:
spline_fn = unconstrained_rational_quadratic_spline
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
outputs, logabsdet = spline_fn(
inputs=inputs,
unnormalized_widths=unnormalized_widths,
unnormalized_heights=unnormalized_heights,
unnormalized_derivatives=unnormalized_derivatives,
inverse=inverse,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
**spline_kwargs,
)
return outputs, logabsdet
def searchsorted(bin_locations, inputs, eps=1e-6):
bin_locations[..., -1] += eps
return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
def unconstrained_rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails="linear",
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask
outputs = torch.zeros_like(inputs)
logabsdet = torch.zeros_like(inputs)
if tails == "linear":
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant
unnormalized_derivatives[..., -1] = constant
outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0
else:
raise RuntimeError("{} tails are not implemented.".format(tails))
outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
inverse=inverse,
left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
)
return outputs, logabsdet
def rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
left=0.0,
right=1.0,
bottom=0.0,
top=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if torch.min(inputs) < left or torch.max(inputs) > right:
raise ValueError("Input to a transform is not within its domain")
num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0:
raise ValueError("Minimal bin width too large for the number of bins")
if min_bin_height * num_bins > 1.0:
raise ValueError("Minimal bin height too large for the number of bins")
widths = F.softmax(unnormalized_widths, dim=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
cumwidths = torch.cumsum(widths, dim=-1)
cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
cumwidths = (right - left) * cumwidths + left
cumwidths[..., 0] = left
cumwidths[..., -1] = right
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
heights = F.softmax(unnormalized_heights, dim=-1)
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
cumheights = torch.cumsum(heights, dim=-1)
cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
cumheights = (top - bottom) * cumheights + bottom
cumheights[..., 0] = bottom
cumheights[..., -1] = top
heights = cumheights[..., 1:] - cumheights[..., :-1]
if inverse:
bin_idx = searchsorted(cumheights, inputs)[..., None]
else:
bin_idx = searchsorted(cumwidths, inputs)[..., None]
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
delta = heights / widths
input_delta = delta.gather(-1, bin_idx)[..., 0]
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
input_heights = heights.gather(-1, bin_idx)[..., 0]
if inverse:
a = (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
) + input_heights * (input_delta - input_derivatives)
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
)
c = -input_delta * (inputs - input_cumheights)
discriminant = b.pow(2) - 4 * a * c
assert (discriminant >= 0).all()
root = (2 * c) / (-b - torch.sqrt(discriminant))
outputs = root * input_bin_widths + input_cumwidths
theta_one_minus_theta = root * (1 - root)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
)
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * root.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - root).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, -logabsdet
else:
theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
)
outputs = input_cumheights + numerator / denominator
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * theta.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - theta).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, logabsdet

View File

@ -212,13 +212,22 @@ class BaseTTS(BaseModel):
else None, else None,
) )
if ( if config.use_phonemes and config.compute_input_seq_cache:
config.use_phonemes if hasattr(self, "eval_data_items") and is_eval:
and config.compute_input_seq_cache dataset.items = self.eval_data_items
and not os.path.exists(dataset.phoneme_cache_path) elif hasattr(self, "train_data_items") and not is_eval:
): dataset.items = self.train_data_items
# precompute phonemes to have a better estimate of sequence lengths. else:
dataset.compute_input_seq(config.num_loader_workers) # precompute phonemes to have a better estimate of sequence lengths.
dataset.compute_input_seq(config.num_loader_workers)
# TODO: find a more efficient solution
# cheap hack - store items in the model state to avoid recomputing when reinit the dataset
if is_eval:
self.eval_data_items = dataset.items
else:
self.train_data_items = dataset.items
dataset.sort_items() dataset.sort_items()
sampler = DistributedSampler(dataset) if num_gpus > 1 else None sampler = DistributedSampler(dataset) if num_gpus > 1 else None

758
TTS/tts/models/vits.py Normal file
View File

@ -0,0 +1,758 @@
from dataclasses import dataclass, field
from typing import Dict, List, Tuple
import torch
from coqpit import Coqpit
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
# from TTS.tts.layers.vits.sdp import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment
from TTS.utils.audio import AudioProcessor
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4):
"""Segment each sample in a batch based on the provided segment indices"""
segments = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
index_start = segment_indices[i]
index_end = index_start + segment_size
segments[i] = x[i, :, index_start:index_end]
return segments
def rand_segment(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4):
"""Create random segments based on the input lengths."""
B, _, T = x.size()
if x_lengths is None:
x_lengths = T
max_idxs = x_lengths - segment_size + 1
assert all(max_idxs > 0), " [!] At least one sample is shorter than the segment size."
ids_str = (torch.rand([B]).type_as(x) * max_idxs).long()
ret = segment(x, ids_str, segment_size)
return ret, ids_str
@dataclass
class VitsArgs(Coqpit):
"""VITS model arguments.
Args:
num_chars (int):
Number of characters in the vocabulary. Defaults to 100.
out_channels (int):
Number of output channels. Defaults to 513.
spec_segment_size (int):
Decoder input segment size. Defaults to 32 `(32 * hoplength = waveform length)`.
hidden_channels (int):
Number of hidden channels of the model. Defaults to 192.
hidden_channels_ffn_text_encoder (int):
Number of hidden channels of the feed-forward layers of the text encoder transformer. Defaults to 256.
num_heads_text_encoder (int):
Number of attention heads of the text encoder transformer. Defaults to 2.
num_layers_text_encoder (int):
Number of transformer layers in the text encoder. Defaults to 6.
kernel_size_text_encoder (int):
Kernel size of the text encoder transformer FFN layers. Defaults to 3.
dropout_p_text_encoder (float):
Dropout rate of the text encoder. Defaults to 0.1.
kernel_size_posterior_encoder (int):
Kernel size of the posterior encoder's WaveNet layers. Defaults to 5.
dilatation_posterior_encoder (int):
Dilation rate of the posterior encoder's WaveNet layers. Defaults to 1.
num_layers_posterior_encoder (int):
Number of posterior encoder's WaveNet layers. Defaults to 16.
kernel_size_flow (int):
Kernel size of the Residual Coupling layers of the flow network. Defaults to 5.
dilatation_flow (int):
Dilation rate of the Residual Coupling WaveNet layers of the flow network. Defaults to 1.
num_layers_flow (int):
Number of Residual Coupling WaveNet layers of the flow network. Defaults to 6.
resblock_type_decoder (str):
Type of the residual block in the decoder network. Defaults to "1".
resblock_kernel_sizes_decoder (List[int]):
Kernel sizes of the residual blocks in the decoder network. Defaults to `[3, 7, 11]`.
resblock_dilation_sizes_decoder (List[List[int]]):
Dilation sizes of the residual blocks in the decoder network. Defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`.
upsample_rates_decoder (List[int]):
Upsampling rates for each concecutive upsampling layer in the decoder network. The multiply of these
values must be equal to the kop length used for computing spectrograms. Defaults to `[8, 8, 2, 2]`.
upsample_initial_channel_decoder (int):
Number of hidden channels of the first upsampling convolution layer of the decoder network. Defaults to 512.
upsample_kernel_sizes_decoder (List[int]):
Kernel sizes for each upsampling layer of the decoder network. Defaults to `[16, 16, 4, 4]`.
use_sdp (int):
Use Stochastic Duration Predictor. Defaults to True.
noise_scale (float):
Noise scale used for the sample noise tensor in training. Defaults to 1.0.
inference_noise_scale (float):
Noise scale used for the sample noise tensor in inference. Defaults to 0.667.
length_scale (int):
Scale factor for the predicted duration values. Smaller values result faster speech. Defaults to 1.
noise_scale_dp (float):
Noise scale used by the Stochastic Duration Predictor sample noise in training. Defaults to 1.0.
inference_noise_scale_dp (float):
Noise scale for the Stochastic Duration Predictor in inference. Defaults to 0.8.
max_inference_len (int):
Maximum inference length to limit the memory use. Defaults to None.
init_discriminator (bool):
Initialize the disciminator network if set True. Set False for inference. Defaults to True.
use_spectral_norm_disriminator (bool):
Use spectral normalization over weight norm in the discriminator. Defaults to False.
use_speaker_embedding (bool):
Enable/Disable speaker embedding for multi-speaker models. Defaults to False.
num_speakers (int):
Number of speakers for the speaker embedding layer. Defaults to 0.
speakers_file (str):
Path to the speaker mapping file for the Speaker Manager. Defaults to None.
speaker_embedding_channels (int):
Number of speaker embedding channels. Defaults to 256.
use_d_vector_file (bool):
Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False.
d_vector_dim (int):
Number of d-vector channels. Defaults to 0.
detach_dp_input (bool):
Detach duration predictor's input from the network for stopping the gradients. Defaults to True.
"""
num_chars: int = 100
out_channels: int = 513
spec_segment_size: int = 32
hidden_channels: int = 192
hidden_channels_ffn_text_encoder: int = 768
num_heads_text_encoder: int = 2
num_layers_text_encoder: int = 6
kernel_size_text_encoder: int = 3
dropout_p_text_encoder: int = 0.1
kernel_size_posterior_encoder: int = 5
dilation_rate_posterior_encoder: int = 1
num_layers_posterior_encoder: int = 16
kernel_size_flow: int = 5
dilation_rate_flow: int = 1
num_layers_flow: int = 4
resblock_type_decoder: int = "1"
resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11])
resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2])
upsample_initial_channel_decoder: int = 512
upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
use_sdp: int = True
noise_scale: float = 1.0
inference_noise_scale: float = 0.667
length_scale: int = 1
noise_scale_dp: float = 1.0
inference_noise_scale_dp: float = 0.8
max_inference_len: int = None
init_discriminator: bool = True
use_spectral_norm_disriminator: bool = False
use_speaker_embedding: bool = False
num_speakers: int = 0
speakers_file: str = None
speaker_embedding_channels: int = 256
use_d_vector_file: bool = False
d_vector_dim: int = 0
detach_dp_input: bool = True
class Vits(BaseTTS):
"""VITS TTS model
Paper::
https://arxiv.org/pdf/2106.06103.pdf
Paper Abstract::
Several recent end-to-end text-to-speech (TTS) models enabling single-stage training and parallel
sampling have been proposed, but their sample quality does not match that of two-stage TTS systems.
In this work, we present a parallel endto-end TTS method that generates more natural sounding audio than
current two-stage models. Our method adopts variational inference augmented with normalizing flows and
an adversarial training process, which improves the expressive power of generative modeling. We also propose a
stochastic duration predictor to synthesize speech with diverse rhythms from input text. With the
uncertainty modeling over latent variables and the stochastic duration predictor, our method expresses the
natural one-to-many relationship in which a text input can be spoken in multiple ways
with different pitches and rhythms. A subjective human evaluation (mean opinion score, or MOS)
on the LJ Speech, a single speaker dataset, shows that our method outperforms the best publicly
available TTS systems and achieves a MOS comparable to ground truth.
Check :class:`TTS.tts.configs.vits_config.VitsConfig` for class arguments.
Examples:
>>> from TTS.tts.configs import VitsConfig
>>> from TTS.tts.models.vits import Vits
>>> config = VitsConfig()
>>> model = Vits(config)
"""
# pylint: disable=dangerous-default-value
def __init__(self, config: Coqpit):
super().__init__()
self.END2END = True
if config.__class__.__name__ == "VitsConfig":
# loading from VitsConfig
if "num_chars" not in config:
_, self.config, num_chars = self.get_characters(config)
config.model_args.num_chars = num_chars
else:
self.config = config
config.model_args.num_chars = config.num_chars
args = self.config.model_args
elif isinstance(config, VitsArgs):
# loading from VitsArgs
self.config = config
args = config
else:
raise ValueError("config must be either a VitsConfig or VitsArgs")
self.args = args
self.init_multispeaker(config)
self.length_scale = args.length_scale
self.noise_scale = args.noise_scale
self.inference_noise_scale = args.inference_noise_scale
self.inference_noise_scale_dp = args.inference_noise_scale_dp
self.noise_scale_dp = args.noise_scale_dp
self.max_inference_len = args.max_inference_len
self.spec_segment_size = args.spec_segment_size
self.text_encoder = TextEncoder(
args.num_chars,
args.hidden_channels,
args.hidden_channels,
args.hidden_channels_ffn_text_encoder,
args.num_heads_text_encoder,
args.num_layers_text_encoder,
args.kernel_size_text_encoder,
args.dropout_p_text_encoder,
)
self.posterior_encoder = PosteriorEncoder(
args.out_channels,
args.hidden_channels,
args.hidden_channels,
kernel_size=args.kernel_size_posterior_encoder,
dilation_rate=args.dilation_rate_posterior_encoder,
num_layers=args.num_layers_posterior_encoder,
cond_channels=self.embedded_speaker_dim,
)
self.flow = ResidualCouplingBlocks(
args.hidden_channels,
args.hidden_channels,
kernel_size=args.kernel_size_flow,
dilation_rate=args.dilation_rate_flow,
num_layers=args.num_layers_flow,
cond_channels=self.embedded_speaker_dim,
)
if args.use_sdp:
self.duration_predictor = StochasticDurationPredictor(
args.hidden_channels, 192, 3, 0.5, 4, cond_channels=self.embedded_speaker_dim
)
else:
self.duration_predictor = DurationPredictor(
args.hidden_channels, 256, 3, 0.5, cond_channels=self.embedded_speaker_dim
)
self.waveform_decoder = HifiganGenerator(
args.hidden_channels,
1,
args.resblock_type_decoder,
args.resblock_dilation_sizes_decoder,
args.resblock_kernel_sizes_decoder,
args.upsample_kernel_sizes_decoder,
args.upsample_initial_channel_decoder,
args.upsample_rates_decoder,
inference_padding=0,
cond_channels=self.embedded_speaker_dim,
conv_pre_weight_norm=False,
conv_post_weight_norm=False,
conv_post_bias=False,
)
if args.init_discriminator:
self.disc = VitsDiscriminator(use_spectral_norm=args.use_spectral_norm_disriminator)
def init_multispeaker(self, config: Coqpit, data: List = None):
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
or with external `d_vectors` computed from a speaker encoder model.
If you need a different behaviour, override this function for your model.
Args:
config (Coqpit): Model configuration.
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
"""
if hasattr(config, "model_args"):
config = config.model_args
self.embedded_speaker_dim = 0
# init speaker manager
self.speaker_manager = get_speaker_manager(config, data=data)
if config.num_speakers > 0 and self.speaker_manager.num_speakers == 0:
self.speaker_manager.num_speakers = config.num_speakers
self.num_speakers = self.speaker_manager.num_speakers
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
self.embedded_speaker_dim = config.speaker_embedding_channels
self.emb_g = nn.Embedding(config.num_speakers, config.speaker_embedding_channels)
# init d-vector usage
if config.use_d_vector_file:
self.embedded_speaker_dim = config.d_vector_dim
@staticmethod
def _set_cond_input(aux_input: Dict):
"""Set the speaker conditioning input based on the multi-speaker mode."""
sid, g = None, None
if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
sid = aux_input["speaker_ids"]
if sid.ndim == 0:
sid = sid.unsqueeze_(0)
if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
g = aux_input["d_vectors"]
return sid, g
def forward(
self,
x: torch.tensor,
x_lengths: torch.tensor,
y: torch.tensor,
y_lengths: torch.tensor,
aux_input={"d_vectors": None, "speaker_ids": None},
) -> Dict:
"""Forward pass of the model.
Args:
x (torch.tensor): Batch of input character sequence IDs.
x_lengths (torch.tensor): Batch of input character sequence lengths.
y (torch.tensor): Batch of input spectrograms.
y_lengths (torch.tensor): Batch of input spectrogram lengths.
aux_input (dict, optional): Auxiliary inputs for multi-speaker training. Defaults to {"d_vectors": None, "speaker_ids": None}.
Returns:
Dict: model outputs keyed by the output name.
Shapes:
- x: :math:`[B, T_seq]`
- x_lengths: :math:`[B]`
- y: :math:`[B, C, T_spec]`
- y_lengths: :math:`[B]`
- d_vectors: :math:`[B, C, 1]`
- speaker_ids: :math:`[B]`
"""
outputs = {}
sid, g = self._set_cond_input(aux_input)
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
# speaker embedding
if self.num_speakers > 1 and sid is not None:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
# posterior encoder
z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
# flow layers
z_p = self.flow(z, y_mask, g=g)
# find the alignment path
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
with torch.no_grad():
o_scale = torch.exp(-2 * logs_p)
# logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1]
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
# logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp = logp2 + logp3
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
# duration predictor
attn_durations = attn.sum(3)
if self.args.use_sdp:
nll_duration = self.duration_predictor(
x.detach() if self.args.detach_dp_input else x,
x_mask,
attn_durations,
g=g.detach() if self.args.detach_dp_input and g is not None else g,
)
nll_duration = torch.sum(nll_duration.float() / torch.sum(x_mask))
outputs["nll_duration"] = nll_duration
else:
attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
log_durations = self.duration_predictor(
x.detach() if self.args.detach_dp_input else x,
x_mask,
g=g.detach() if self.args.detach_dp_input and g is not None else g,
)
loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
outputs["loss_duration"] = loss_duration
# expand prior
m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
# select a random feature segment for the waveform decoder
z_slice, slice_ids = rand_segment(z, y_lengths, self.spec_segment_size)
o = self.waveform_decoder(z_slice, g=g)
outputs.update(
{
"model_outputs": o,
"alignments": attn.squeeze(1),
"slice_ids": slice_ids,
"z": z,
"z_p": z_p,
"m_p": m_p,
"logs_p": logs_p,
"m_q": m_q,
"logs_q": logs_q,
}
)
return outputs
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
"""
Shapes:
- x: :math:`[B, T_seq]`
- d_vectors: :math:`[B, C, 1]`
- speaker_ids: :math:`[B]`
"""
sid, g = self._set_cond_input(aux_input)
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
if self.num_speakers > 0 and sid:
g = self.emb_g(sid).unsqueeze(-1)
if self.args.use_sdp:
logw = self.duration_predictor(x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp)
else:
logw = self.duration_predictor(x, x_mask, g=g)
w = torch.exp(logw) * x_mask * self.length_scale
w_ceil = torch.ceil(w)
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype)
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))
m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
z = self.flow(z_p, y_mask, g=g, reverse=True)
o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)
outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p}
return outputs
def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
"""TODO: create an end-point for voice conversion"""
assert self.num_speakers > 0, "num_speakers have to be larger than 0."
g_src = self.emb_g(sid_src).unsqueeze(-1)
g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
z, _, _, y_mask = self.enc_q(y, y_lengths, g=g_src)
z_p = self.flow(z, y_mask, g=g_src)
z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
return o_hat, y_mask, (z, z_p, z_hat)
def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
"""Perform a single training step. Run the model forward pass and compute losses.
Args:
batch (Dict): Input tensors.
criterion (nn.Module): Loss layer designed for the model.
optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.
Returns:
Tuple[Dict, Dict]: Model ouputs and computed losses.
"""
# pylint: disable=attribute-defined-outside-init
if optimizer_idx not in [0, 1]:
raise ValueError(" [!] Unexpected `optimizer_idx`.")
if optimizer_idx == 0:
text_input = batch["text_input"]
text_lengths = batch["text_lengths"]
mel_lengths = batch["mel_lengths"]
linear_input = batch["linear_input"]
d_vectors = batch["d_vectors"]
speaker_ids = batch["speaker_ids"]
waveform = batch["waveform"]
# generator pass
outputs = self.forward(
text_input,
text_lengths,
linear_input.transpose(1, 2),
mel_lengths,
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
)
# cache tensors for the discriminator
self.y_disc_cache = None
self.wav_seg_disc_cache = None
self.y_disc_cache = outputs["model_outputs"]
wav_seg = segment(
waveform.transpose(1, 2),
outputs["slice_ids"] * self.config.audio.hop_length,
self.args.spec_segment_size * self.config.audio.hop_length,
)
self.wav_seg_disc_cache = wav_seg
outputs["waveform_seg"] = wav_seg
# compute discriminator scores and features
outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(outputs["model_outputs"])
_, outputs["feats_disc_real"] = self.disc(wav_seg)
# compute losses
with autocast(enabled=False): # use float32 for the criterion
loss_dict = criterion[optimizer_idx](
waveform_hat=outputs["model_outputs"].float(),
waveform=wav_seg.float(),
z_p=outputs["z_p"].float(),
logs_q=outputs["logs_q"].float(),
m_p=outputs["m_p"].float(),
logs_p=outputs["logs_p"].float(),
z_len=mel_lengths,
scores_disc_fake=outputs["scores_disc_fake"],
feats_disc_fake=outputs["feats_disc_fake"],
feats_disc_real=outputs["feats_disc_real"],
)
# handle the duration loss
if self.args.use_sdp:
loss_dict["nll_duration"] = outputs["nll_duration"]
loss_dict["loss"] += outputs["nll_duration"]
else:
loss_dict["loss_duration"] = outputs["loss_duration"]
loss_dict["loss"] += outputs["nll_duration"]
elif optimizer_idx == 1:
# discriminator pass
outputs = {}
# compute scores and features
outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(self.y_disc_cache.detach())
outputs["scores_disc_real"], outputs["feats_disc_real"] = self.disc(self.wav_seg_disc_cache)
# compute loss
with autocast(enabled=False): # use float32 for the criterion
loss_dict = criterion[optimizer_idx](
outputs["scores_disc_real"],
outputs["scores_disc_fake"],
)
return outputs, loss_dict
def train_log(
self, ap: AudioProcessor, batch: Dict, outputs: List, name_prefix="train"
): # pylint: disable=no-self-use
"""Create visualizations and waveform examples.
For example, here you can plot spectrograms and generate sample sample waveforms from these spectrograms to
be projected onto Tensorboard.
Args:
ap (AudioProcessor): audio processor used at training.
batch (Dict): Model inputs used at the previous training step.
outputs (Dict): Model outputs generated at the previoud training step.
Returns:
Tuple[Dict, np.ndarray]: training plots and output waveform.
"""
y_hat = outputs[0]["model_outputs"]
y = outputs[0]["waveform_seg"]
figures = plot_results(y_hat, y, ap, name_prefix)
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
audios = {f"{name_prefix}/audio": sample_voice}
alignments = outputs[0]["alignments"]
align_img = alignments[0].data.cpu().numpy().T
figures.update(
{
"alignment": plot_alignment(align_img, output_fig=False),
}
)
return figures, audios
@torch.no_grad()
def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
return self.train_step(batch, criterion, optimizer_idx)
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
return self.train_log(ap, batch, outputs, "eval")
@torch.no_grad()
def test_run(self, ap) -> Tuple[Dict, Dict]:
"""Generic test run for `tts` models used by `Trainer`.
You can override this for a different behaviour.
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
print(" | > Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
aux_inputs = self.get_aux_input()
for idx, sen in enumerate(test_sentences):
wav, alignment, _, _ = synthesis(
self,
sen,
self.config,
"cuda" in str(next(self.parameters()).device),
ap,
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
use_griffin_lim=True,
do_trim_silence=False,
).values()
test_audios["{}-audio".format(idx)] = wav
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
return test_figures, test_audios
def get_optimizer(self) -> List:
"""Initiate and return the GAN optimizers based on the config parameters.
It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator.
Returns:
List: optimizers.
"""
self.disc.requires_grad_(False)
gen_parameters = filter(lambda p: p.requires_grad, self.parameters())
self.disc.requires_grad_(True)
optimizer1 = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
)
optimizer2 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
return [optimizer1, optimizer2]
def get_lr(self) -> List:
"""Set the initial learning rates for each optimizer.
Returns:
List: learning rates for each optimizer.
"""
return [self.config.lr_gen, self.config.lr_disc]
def get_scheduler(self, optimizer) -> List:
"""Set the schedulers for each optimizer.
Args:
optimizer (List[`torch.optim.Optimizer`]): List of optimizers.
Returns:
List: Schedulers, one for each optimizer.
"""
scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
return [scheduler1, scheduler2]
def get_criterion(self):
"""Get criterions for each optimizer. The index in the output list matches the optimizer idx used in
`train_step()`"""
from TTS.tts.layers.losses import ( # pylint: disable=import-outside-toplevel
VitsDiscriminatorLoss,
VitsGeneratorLoss,
)
return [VitsGeneratorLoss(self.config), VitsDiscriminatorLoss(self.config)]
@staticmethod
def make_symbols(config):
"""Create a custom arrangement of symbols used by the model. The output list of symbols propagate along the
whole training and inference steps."""
_pad = config.characters["pad"]
_punctuations = config.characters["punctuations"]
_letters = config.characters["characters"]
_letters_ipa = config.characters["phonemes"]
symbols = [_pad] + list(_punctuations) + list(_letters)
if config.use_phonemes:
symbols += list(_letters_ipa)
return symbols
@staticmethod
def get_characters(config: Coqpit):
if config.characters is not None:
symbols = Vits.make_symbols(config)
else:
from TTS.tts.utils.text.symbols import ( # pylint: disable=import-outside-toplevel
parse_symbols,
phonemes,
symbols,
)
config.characters = parse_symbols()
if config.use_phonemes:
symbols = phonemes
num_chars = len(symbols) + getattr(config, "add_blank", False)
return symbols, config, num_chars
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
"""Load the model checkpoint and setup for training or inference"""
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()
assert not self.training

View File

@ -81,7 +81,6 @@ def text2phone(text, language, use_espeak_phonemes=False):
# Fix a few phonemes # Fix a few phonemes
ph = ph.translate(GRUUT_TRANS_TABLE) ph = ph.translate(GRUUT_TRANS_TABLE)
# print(" > Phonemes: {}".format(ph))
return ph return ph
raise ValueError(f" [!] Language {language} is not supported for phonemization.") raise ValueError(f" [!] Language {language} is not supported for phonemization.")
@ -116,6 +115,7 @@ def phoneme_to_sequence(
use_espeak_phonemes: bool = False, use_espeak_phonemes: bool = False,
) -> List[int]: ) -> List[int]:
"""Converts a string of phonemes to a sequence of IDs. """Converts a string of phonemes to a sequence of IDs.
If `custom_symbols` is provided, it will override the default symbols.
Args: Args:
text (str): string to convert to a sequence text (str): string to convert to a sequence
@ -132,12 +132,11 @@ def phoneme_to_sequence(
# pylint: disable=global-statement # pylint: disable=global-statement
global _phonemes_to_id, _phonemes global _phonemes_to_id, _phonemes
if tp: if custom_symbols is not None:
_, _phonemes = make_symbols(**tp)
_phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
elif custom_symbols is not None:
_phonemes = custom_symbols _phonemes = custom_symbols
_phonemes_to_id = {s: i for i, s in enumerate(custom_symbols)} elif tp:
_, _phonemes = make_symbols(**tp)
_phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
sequence = [] sequence = []
clean_text = _clean_text(text, cleaner_names) clean_text = _clean_text(text, cleaner_names)
@ -155,16 +154,19 @@ def phoneme_to_sequence(
return sequence return sequence
def sequence_to_phoneme(sequence, tp=None, add_blank=False): def sequence_to_phoneme(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List["str"] = None):
# pylint: disable=global-statement # pylint: disable=global-statement
"""Converts a sequence of IDs back to a string""" """Converts a sequence of IDs back to a string"""
global _id_to_phonemes, _phonemes global _id_to_phonemes, _phonemes
if add_blank: if add_blank:
sequence = list(filter(lambda x: x != len(_phonemes), sequence)) sequence = list(filter(lambda x: x != len(_phonemes), sequence))
result = "" result = ""
if tp:
if custom_symbols is not None:
_phonemes = custom_symbols
elif tp:
_, _phonemes = make_symbols(**tp) _, _phonemes = make_symbols(**tp)
_id_to_phonemes = {i: s for i, s in enumerate(_phonemes)} _id_to_phonemes = {i: s for i, s in enumerate(_phonemes)}
for symbol_id in sequence: for symbol_id in sequence:
if symbol_id in _id_to_phonemes: if symbol_id in _id_to_phonemes:
@ -177,6 +179,7 @@ def text_to_sequence(
text: str, cleaner_names: List[str], custom_symbols: List[str] = None, tp: Dict = None, add_blank: bool = False text: str, cleaner_names: List[str], custom_symbols: List[str] = None, tp: Dict = None, add_blank: bool = False
) -> List[int]: ) -> List[int]:
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text. """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
If `custom_symbols` is provided, it will override the default symbols.
Args: Args:
text (str): string to convert to a sequence text (str): string to convert to a sequence
@ -189,12 +192,12 @@ def text_to_sequence(
""" """
# pylint: disable=global-statement # pylint: disable=global-statement
global _symbol_to_id, _symbols global _symbol_to_id, _symbols
if tp:
_symbols, _ = make_symbols(**tp) if custom_symbols is not None:
_symbol_to_id = {s: i for i, s in enumerate(_symbols)}
elif custom_symbols is not None:
_symbols = custom_symbols _symbols = custom_symbols
_symbol_to_id = {s: i for i, s in enumerate(custom_symbols)} elif tp:
_symbols, _ = make_symbols(**tp)
_symbol_to_id = {s: i for i, s in enumerate(_symbols)}
sequence = [] sequence = []
@ -213,16 +216,18 @@ def text_to_sequence(
return sequence return sequence
def sequence_to_text(sequence, tp=None, add_blank=False): def sequence_to_text(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List[str] = None):
"""Converts a sequence of IDs back to a string""" """Converts a sequence of IDs back to a string"""
# pylint: disable=global-statement # pylint: disable=global-statement
global _id_to_symbol, _symbols global _id_to_symbol, _symbols
if add_blank: if add_blank:
sequence = list(filter(lambda x: x != len(_symbols), sequence)) sequence = list(filter(lambda x: x != len(_symbols), sequence))
if tp: if custom_symbols is not None:
_symbols = custom_symbols
elif tp:
_symbols, _ = make_symbols(**tp) _symbols, _ = make_symbols(**tp)
_id_to_symbol = {i: s for i, s in enumerate(_symbols)} _id_to_symbol = {i: s for i, s in enumerate(_symbols)}
result = "" result = ""
for symbol_id in sequence: for symbol_id in sequence:

View File

@ -96,10 +96,12 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
) )
self.mel_basis = torch.from_numpy(mel_basis).float() self.mel_basis = torch.from_numpy(mel_basis).float()
def _amp_to_db(self, x, spec_gain=1.0): @staticmethod
def _amp_to_db(x, spec_gain=1.0):
return torch.log(torch.clamp(x, min=1e-5) * spec_gain) return torch.log(torch.clamp(x, min=1e-5) * spec_gain)
def _db_to_amp(self, x, spec_gain=1.0): @staticmethod
def _db_to_amp(x, spec_gain=1.0):
return torch.exp(x) / spec_gain return torch.exp(x) / spec_gain

View File

@ -33,10 +33,10 @@ class DiscriminatorP(torch.nn.Module):
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
self.convs = nn.ModuleList( self.convs = nn.ModuleList(
[ [
norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
] ]
) )
@ -81,15 +81,15 @@ class MultiPeriodDiscriminator(torch.nn.Module):
Periods are suggested to be prime numbers to reduce the overlap between each discriminator. Periods are suggested to be prime numbers to reduce the overlap between each discriminator.
""" """
def __init__(self): def __init__(self, use_spectral_norm=False):
super().__init__() super().__init__()
self.discriminators = nn.ModuleList( self.discriminators = nn.ModuleList(
[ [
DiscriminatorP(2), DiscriminatorP(2, use_spectral_norm=use_spectral_norm),
DiscriminatorP(3), DiscriminatorP(3, use_spectral_norm=use_spectral_norm),
DiscriminatorP(5), DiscriminatorP(5, use_spectral_norm=use_spectral_norm),
DiscriminatorP(7), DiscriminatorP(7, use_spectral_norm=use_spectral_norm),
DiscriminatorP(11), DiscriminatorP(11, use_spectral_norm=use_spectral_norm),
] ]
) )
@ -99,7 +99,7 @@ class MultiPeriodDiscriminator(torch.nn.Module):
x (Tensor): input waveform. x (Tensor): input waveform.
Returns: Returns:
[List[Tensor]]: list of scores from each discriminator. [List[Tensor]]: list of scores from each discriminator.
[List[List[Tensor]]]: list of list of features from each discriminator's each convolutional layer. [List[List[Tensor]]]: list of list of features from each discriminator's each convolutional layer.
Shapes: Shapes:

View File

@ -170,6 +170,10 @@ class HifiganGenerator(torch.nn.Module):
upsample_initial_channel, upsample_initial_channel,
upsample_factors, upsample_factors,
inference_padding=5, inference_padding=5,
cond_channels=0,
conv_pre_weight_norm=True,
conv_post_weight_norm=True,
conv_post_bias=True,
): ):
r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF) r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)
@ -218,12 +222,21 @@ class HifiganGenerator(torch.nn.Module):
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d)) self.resblocks.append(resblock(ch, k, d))
# post convolution layer # post convolution layer
self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3)) self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
if cond_channels > 0:
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
def forward(self, x): if not conv_pre_weight_norm:
remove_weight_norm(self.conv_pre)
if not conv_post_weight_norm:
remove_weight_norm(self.conv_post)
def forward(self, x, g=None):
""" """
Args: Args:
x (Tensor): conditioning input tensor. x (Tensor): feature input tensor.
g (Tensor): global conditioning input tensor.
Returns: Returns:
Tensor: output waveform. Tensor: output waveform.
@ -233,6 +246,8 @@ class HifiganGenerator(torch.nn.Module):
Tensor: [B, 1, T] Tensor: [B, 1, T]
""" """
o = self.conv_pre(x) o = self.conv_pre(x)
if hasattr(self, "cond_layer"):
o = o + self.cond_layer(g)
for i in range(self.num_upsamples): for i in range(self.num_upsamples):
o = F.leaky_relu(o, LRELU_SLOPE) o = F.leaky_relu(o, LRELU_SLOPE)
o = self.ups[i](o) o = self.ups[i](o)

View File

@ -1,3 +1,5 @@
from typing import List
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -10,18 +12,35 @@ LRELU_SLOPE = 0.1
class UnivnetGenerator(torch.nn.Module): class UnivnetGenerator(torch.nn.Module):
def __init__( def __init__(
self, self,
in_channels, in_channels: int,
out_channels, out_channels: int,
hidden_channels, hidden_channels: int,
cond_channels, cond_channels: int,
upsample_factors, upsample_factors: List[int],
lvc_layers_each_block, lvc_layers_each_block: int,
lvc_kernel_size, lvc_kernel_size: int,
kpnet_hidden_channels, kpnet_hidden_channels: int,
kpnet_conv_size, kpnet_conv_size: int,
dropout, dropout: float,
use_weight_norm=True, use_weight_norm=True,
): ):
"""Univnet Generator network.
Paper: https://arxiv.org/pdf/2106.07889.pdf
Args:
in_channels (int): Number of input tensor channels.
out_channels (int): Number of channels of the output tensor.
hidden_channels (int): Number of hidden network channels.
cond_channels (int): Number of channels of the conditioning tensors.
upsample_factors (List[int]): List of uplsample factors for the upsampling layers.
lvc_layers_each_block (int): Number of LVC layers in each block.
lvc_kernel_size (int): Kernel size of the LVC layers.
kpnet_hidden_channels (int): Number of hidden channels in the key-point network.
kpnet_conv_size (int): Number of convolution channels in the key-point network.
dropout (float): Dropout rate.
use_weight_norm (bool, optional): Enable/disable weight norm. Defaults to True.
"""
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels

View File

@ -1,8 +1,11 @@
from typing import Dict
import numpy as np import numpy as np
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from TTS.tts.utils.visual import plot_spectrogram from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor
def interpolate_vocoder_input(scale_factor, spec): def interpolate_vocoder_input(scale_factor, spec):
@ -26,12 +29,24 @@ def interpolate_vocoder_input(scale_factor, spec):
return spec return spec
def plot_results(y_hat, y, ap, name_prefix): def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_prefix: str = None) -> Dict:
"""Plot vocoder model results""" """Plot the predicted and the real waveform and their spectrograms.
Args:
y_hat (torch.tensor): Predicted waveform.
y (torch.tensor): Real waveform.
ap (AudioProcessor): Audio processor used to process the waveform.
name_prefix (str, optional): Name prefix used to name the figures. Defaults to None.
Returns:
Dict: output figures keyed by the name of the figures.
""" """Plot vocoder model results"""
if name_prefix is None:
name_prefix = ""
# select an instance from batch # select an instance from batch
y_hat = y_hat[0].squeeze(0).detach().cpu().numpy() y_hat = y_hat[0].squeeze().detach().cpu().numpy()
y = y[0].squeeze(0).detach().cpu().numpy() y = y[0].squeeze().detach().cpu().numpy()
spec_fake = ap.melspectrogram(y_hat).T spec_fake = ap.melspectrogram(y_hat).T
spec_real = ap.melspectrogram(y).T spec_real = ap.melspectrogram(y).T

View File

@ -0,0 +1,54 @@
import glob
import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import VitsConfig
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
config = VitsConfig(
batch_size=2,
eval_batch_size=2,
num_loader_workers=0,
num_eval_loader_workers=0,
text_cleaner="english_cleaners",
use_phonemes=True,
use_espeak_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
run_eval=True,
test_delay_epochs=-1,
epochs=1,
print_step=1,
print_eval=True,
test_sentences=[
"Be a voice, not an echo.",
],
)
config.audio.do_trim_silence = True
config.audio.trim_db = 60
config.save_json(config_path)
# train the model for one epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
f"--coqpit.output_path {output_path} "
"--coqpit.datasets.0.name ljspeech "
"--coqpit.datasets.0.meta_file_train metadata.csv "
"--coqpit.datasets.0.meta_file_val metadata.csv "
"--coqpit.datasets.0.path tests/data/ljspeech "
"--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
"--coqpit.test_delay_epochs 0"
)
run_cli(command_train)
# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)