mirror of https://github.com/coqui-ai/TTS.git
Implement VITS model 🚀
VITS model implementation built on Glow TTS and HiFiGAN layers.
parent 060e746e21
commit c312acac7d
@@ -0,0 +1,60 @@
from dataclasses import dataclass, field
from typing import List

from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.vits import VitsArgs


@dataclass
class VitsConfig(BaseTTSConfig):
    """Defines parameters for VITS End2End TTS model.

    Example:

        >>> from TTS.tts.configs import VitsConfig
        >>> config = VitsConfig()
    """

    model: str = "vits"
    # model specific params
    model_args: VitsArgs = field(default_factory=VitsArgs)

    # optimizer
    grad_clip: float = field(default_factory=lambda: [5, 5])
    lr_gen: float = 0.0002
    lr_disc: float = 0.0002
    lr_scheduler_gen: str = "ExponentialLR"
    lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
    lr_scheduler_disc: str = "ExponentialLR"
    lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
    scheduler_after_epoch: bool = True
    optimizer: str = "AdamW"
    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01})

    # loss params
    kl_loss_alpha: float = 1.0
    disc_loss_alpha: float = 1.0
    gen_loss_alpha: float = 1.0
    feat_loss_alpha: float = 1.0
    mel_loss_alpha: float = 45.0

    # data loader params
    return_wav: bool = True
    compute_linear_spec: bool = True

    # overrides
    min_seq_len: int = 13
    max_seq_len: int = 200
    r: int = 1  # DO NOT CHANGE
    add_blank: bool = True

    # testing
    test_sentences: List[str] = field(
        default_factory=lambda: [
            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
            "Be a voice, not an echo.",
            "I'm sorry Dave. I'm afraid I can't do that.",
            "This cake is great. It's so delicious and moist.",
            "Prior to November 22, 1963.",
        ]
    )
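A hypothetical usage sketch beyond the docstring example: since `VitsConfig` is a dataclass, fields can be overridden after construction (the values below are illustrative, not recommendations).

    from TTS.tts.configs import VitsConfig

    config = VitsConfig()
    config.lr_gen = 1e-4                           # override the generator learning rate
    config.optimizer_params["weight_decay"] = 0.0  # tweak the shared AdamW settings
    print(config.model, config.mel_loss_alpha)     # vits 45.0
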
@@ -191,6 +191,7 @@ class TTSDataset(Dataset):
         else:
             text, wav_file, speaker_name = item
             attn = None
+        raw_text = text
 
         wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
@@ -236,6 +237,7 @@ class TTSDataset(Dataset):
             return self.load_data(self.rescue_item_idx)
 
         sample = {
+            "raw_text": raw_text,
             "text": text,
             "wav": wav,
             "attn": attn,
@@ -360,6 +362,7 @@ class TTSDataset(Dataset):
             wav = [batch[idx]["wav"] for idx in ids_sorted_decreasing]
             item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing]
             text = [batch[idx]["text"] for idx in ids_sorted_decreasing]
+            raw_text = [batch[idx]["raw_text"] for idx in ids_sorted_decreasing]
 
             speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
             # get pre-computed d-vectors
@@ -450,6 +453,7 @@ class TTSDataset(Dataset):
                 attns = torch.FloatTensor(attns).unsqueeze(1)
             else:
                 attns = None
+            # TODO: return dictionary
             return (
                 text,
                 text_lenghts,
@@ -463,6 +467,7 @@ class TTSDataset(Dataset):
                 speaker_ids,
                 attns,
                 wav_padded,
+                raw_text,
             )
 
         raise TypeError(
@@ -28,6 +28,31 @@ class LayerNorm(nn.Module):
         return x
 
 
+class LayerNorm2(nn.Module):
+    """Layer norm for the 2nd dimension of the input using the torch primitive.
+
+    Args:
+        channels (int): number of channels (2nd dimension) of the input.
+        eps (float): small constant to prevent division by zero.
+
+    Shapes:
+        - input: (B, C, T)
+        - output: (B, C, T)
+    """
+
+    def __init__(self, channels, eps=1e-5):
+        super().__init__()
+        self.channels = channels
+        self.eps = eps
+
+        self.gamma = nn.Parameter(torch.ones(channels))
+        self.beta = nn.Parameter(torch.zeros(channels))
+
+    def forward(self, x):
+        x = x.transpose(1, -1)
+        x = torch.nn.functional.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+        return x.transpose(1, -1)
+
+
 class TemporalBatchNorm1d(nn.BatchNorm1d):
     """Normalize each channel separately over time and batch."""
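A quick sanity check for `LayerNorm2` (a hypothetical sketch, assuming the module above is importable from `TTS.tts.layers.generic.normalization`):

    import torch
    from TTS.tts.layers.generic.normalization import LayerNorm2

    x = torch.randn(8, 192, 50)  # (B, C, T)
    norm = LayerNorm2(192)
    y = norm(x)
    assert y.shape == x.shape
    # with gamma=1 and beta=0 at init, each (batch, time) position is
    # normalized over the channel axis, so per-position channel means are ~0
    print(y.mean(dim=1).abs().max())
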
@@ -18,7 +18,7 @@ class DurationPredictor(nn.Module):
         dropout_p (float): Dropout rate used after each conv layer.
     """
 
-    def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p):
+    def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None):
         super().__init__()
         # class arguments
         self.in_channels = in_channels
@@ -33,13 +33,18 @@ class DurationPredictor(nn.Module):
         self.norm_2 = LayerNorm(hidden_channels)
         # output layer
         self.proj = nn.Conv1d(hidden_channels, 1, 1)
+        if cond_channels is not None and cond_channels != 0:
+            self.cond = nn.Conv1d(cond_channels, in_channels, 1)
 
-    def forward(self, x, x_mask):
+    def forward(self, x, x_mask, g=None):
         """
         Shapes:
             - x: :math:`[B, C, T]`
             - x_mask: :math:`[B, 1, T]`
+            - g: :math:`[B, C, 1]`
         """
+        if g is not None:
+            x = x + self.cond(g)
         x = self.conv_1(x * x_mask)
         x = torch.relu(x)
         x = self.norm_1(x)
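A minimal sketch of the new conditioning path (hypothetical values; `g` would typically be a speaker embedding, and the predictor returns log-durations of shape `[B, 1, T]`):

    import torch

    dp = DurationPredictor(in_channels=192, hidden_channels=256, kernel_size=3, dropout_p=0.1, cond_channels=512)
    x = torch.randn(2, 192, 50)   # encoder outputs [B, C, T]
    x_mask = torch.ones(2, 1, 50)
    g = torch.randn(2, 512, 1)    # per-utterance conditioning [B, C_cond, 1]
    log_dur = dp(x, x_mask, g=g)  # the projected g is added to x before the convs
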
@@ -16,7 +16,7 @@ class ResidualConv1dLayerNormBlock(nn.Module):
     ::
 
         x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o
-          |---------------> conv1d_1x1 -----------------------|
+          |---------------> conv1d_1x1 ------------------|
 
     Args:
         in_channels (int): number of input tensor channels.
@@ -4,7 +4,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
-from TTS.tts.layers.glow_tts.glow import LayerNorm
+from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2
 
 
 class RelativePositionMultiHeadAttention(nn.Module):
@@ -271,7 +271,7 @@ class FeedForwardNetwork(nn.Module):
         dropout_p (float, optional): dropout rate. Defaults to 0.
     """
 
-    def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dropout_p=0.0):
+    def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dropout_p=0.0, causal=False):
 
         super().__init__()
         self.in_channels = in_channels
@@ -280,17 +280,46 @@ class FeedForwardNetwork(nn.Module):
         self.kernel_size = kernel_size
         self.dropout_p = dropout_p
 
-        self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
-        self.conv_2 = nn.Conv1d(hidden_channels, out_channels, kernel_size, padding=kernel_size // 2)
+        if causal:
+            self.padding = self._causal_padding
+        else:
+            self.padding = self._same_padding
+
+        self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size)
+        self.conv_2 = nn.Conv1d(hidden_channels, out_channels, kernel_size)
         self.dropout = nn.Dropout(dropout_p)
 
     def forward(self, x, x_mask):
-        x = self.conv_1(x * x_mask)
+        x = self.conv_1(self.padding(x * x_mask))
         x = torch.relu(x)
         x = self.dropout(x)
-        x = self.conv_2(x * x_mask)
+        x = self.conv_2(self.padding(x * x_mask))
         return x * x_mask
+
+    def _causal_padding(self, x):
+        if self.kernel_size == 1:
+            return x
+        pad_l = self.kernel_size - 1
+        pad_r = 0
+        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+        x = F.pad(x, self._pad_shape(padding))
+        return x
+
+    def _same_padding(self, x):
+        if self.kernel_size == 1:
+            return x
+        pad_l = (self.kernel_size - 1) // 2
+        pad_r = self.kernel_size // 2
+        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+        x = F.pad(x, self._pad_shape(padding))
+        return x
+
+    @staticmethod
+    def _pad_shape(padding):
+        l = padding[::-1]
+        pad_shape = [item for sublist in l for item in sublist]
+        return pad_shape
 
 
 class RelativePositionTransformer(nn.Module):
     """Transformer with Relative Positional Encoding.
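A sanity check of the two padding modes above (a hypothetical sketch): both preserve the time dimension, but the causal mode pads only on the left so each output frame depends on current and past frames only.

    import torch

    ffn_same = FeedForwardNetwork(192, 192, 768, kernel_size=3, causal=False)
    ffn_causal = FeedForwardNetwork(192, 192, 768, kernel_size=3, causal=True)
    x = torch.randn(2, 192, 40)
    x_mask = torch.ones(2, 1, 40)
    assert ffn_same(x, x_mask).shape == (2, 192, 40)
    assert ffn_causal(x, x_mask).shape == (2, 192, 40)
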
@@ -310,20 +339,23 @@ class RelativePositionTransformer(nn.Module):
             If default, relative encoding is disabled and it is a regular transformer.
             Defaults to None.
         input_length (int, optional): input length to limit position encoding. Defaults to None.
+        layer_norm_type (str, optional): type "1" uses torch tensor operations and type "2" uses the torch
+            layer_norm primitive. Use type "2"; type "1" is for backward compatibility. Defaults to "1".
     """
 
     def __init__(
         self,
-        in_channels,
-        out_channels,
-        hidden_channels,
-        hidden_channels_ffn,
-        num_heads,
-        num_layers,
+        in_channels: int,
+        out_channels: int,
+        hidden_channels: int,
+        hidden_channels_ffn: int,
+        num_heads: int,
+        num_layers: int,
         kernel_size=1,
         dropout_p=0.0,
-        rel_attn_window_size=None,
-        input_length=None,
+        rel_attn_window_size: int = None,
+        input_length: int = None,
+        layer_norm_type: str = "1",
     ):
         super().__init__()
         self.hidden_channels = hidden_channels
@@ -351,7 +383,12 @@ class RelativePositionTransformer(nn.Module):
                     input_length=input_length,
                 )
             )
-            self.norm_layers_1.append(LayerNorm(hidden_channels))
+            if layer_norm_type == "1":
+                self.norm_layers_1.append(LayerNorm(hidden_channels))
+            elif layer_norm_type == "2":
+                self.norm_layers_1.append(LayerNorm2(hidden_channels))
+            else:
+                raise ValueError(" [!] Unknown layer norm type")
 
             if hidden_channels != out_channels and (idx + 1) == self.num_layers:
                 self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
@@ -366,7 +403,12 @@ class RelativePositionTransformer(nn.Module):
                 )
             )
 
-            self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels))
+            if layer_norm_type == "1":
+                self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels))
+            elif layer_norm_type == "2":
+                self.norm_layers_2.append(LayerNorm2(hidden_channels if (idx + 1) != self.num_layers else out_channels))
+            else:
+                raise ValueError(" [!] Unknown layer norm type")
 
     def forward(self, x, x_mask):
         """
@@ -2,11 +2,13 @@ import math
 
 import numpy as np
 import torch
+from coqpit import Coqpit
 from torch import nn
 from torch.nn import functional
 
 from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.ssim import ssim
+from TTS.utils.audio import TorchSTFT
 
 
 # pylint: disable=abstract-method
@@ -514,3 +516,142 @@ class AlignTTSLoss(nn.Module):
             + self.mdn_alpha * mdn_loss
         )
         return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss}
+
+
+class VitsGeneratorLoss(nn.Module):
+    def __init__(self, c: Coqpit):
+        super().__init__()
+        self.kl_loss_alpha = c.kl_loss_alpha
+        self.gen_loss_alpha = c.gen_loss_alpha
+        self.feat_loss_alpha = c.feat_loss_alpha
+        self.mel_loss_alpha = c.mel_loss_alpha
+        self.stft = TorchSTFT(
+            c.audio.fft_size,
+            c.audio.hop_length,
+            c.audio.win_length,
+            sample_rate=c.audio.sample_rate,
+            mel_fmin=c.audio.mel_fmin,
+            mel_fmax=c.audio.mel_fmax,
+            n_mels=c.audio.num_mels,
+            use_mel=True,
+            do_amp_to_db=True,
+        )
+
+    @staticmethod
+    def feature_loss(feats_real, feats_generated):
+        loss = 0
+        for dr, dg in zip(feats_real, feats_generated):
+            for rl, gl in zip(dr, dg):
+                rl = rl.float().detach()
+                gl = gl.float()
+                loss += torch.mean(torch.abs(rl - gl))
+        return loss * 2
+
+    @staticmethod
+    def generator_loss(scores_fake):
+        loss = 0
+        gen_losses = []
+        for dg in scores_fake:
+            dg = dg.float()
+            l = torch.mean((1 - dg) ** 2)
+            gen_losses.append(l)
+            loss += l
+        return loss, gen_losses
+
+    @staticmethod
+    def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
+        """
+        z_p, logs_q: [b, h, t_t]
+        m_p, logs_p: [b, h, t_t]
+        """
+        z_p = z_p.float()
+        logs_q = logs_q.float()
+        m_p = m_p.float()
+        logs_p = logs_p.float()
+        z_mask = z_mask.float()
+
+        kl = logs_p - logs_q - 0.5
+        kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
+        kl = torch.sum(kl * z_mask)
+        l = kl / torch.sum(z_mask)
+        return l
+
+    def forward(
+        self,
+        waveform,
+        waveform_hat,
+        z_p,
+        logs_q,
+        m_p,
+        logs_p,
+        z_len,
+        scores_disc_fake,
+        feats_disc_fake,
+        feats_disc_real,
+    ):
+        """
+        Shapes:
+            - waveform: :math:`[B, 1, T]`
+            - waveform_hat: :math:`[B, 1, T]`
+            - z_p: :math:`[B, C, T]`
+            - logs_q: :math:`[B, C, T]`
+            - m_p: :math:`[B, C, T]`
+            - logs_p: :math:`[B, C, T]`
+            - z_len: :math:`[B]`
+            - scores_disc_fake[i]: :math:`[B, C]`
+            - feats_disc_fake[i][j]: :math:`[B, C, T', P]`
+            - feats_disc_real[i][j]: :math:`[B, C, T', P]`
+        """
+        loss = 0.0
+        return_dict = {}
+        z_mask = sequence_mask(z_len).float()
+        # compute mel spectrograms from the waveforms
+        mel = self.stft(waveform)
+        mel_hat = self.stft(waveform_hat)
+        # compute losses
+        # real features go first so they are the ones detached inside feature_loss
+        # and the gradient flows through the generated features
+        loss_feat = self.feature_loss(feats_disc_real, feats_disc_fake) * self.feat_loss_alpha
+        loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
+        loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
+        loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
+        loss = loss_kl + loss_feat + loss_mel + loss_gen
+        # pass losses to the dict
+        return_dict["loss_gen"] = loss_gen
+        return_dict["loss_kl"] = loss_kl
+        return_dict["loss_feat"] = loss_feat
+        return_dict["loss_mel"] = loss_mel
+        return_dict["loss"] = loss
+        return return_dict
+
+
+class VitsDiscriminatorLoss(nn.Module):
+    def __init__(self, c: Coqpit):
+        super().__init__()
+        self.disc_loss_alpha = c.disc_loss_alpha
+
+    @staticmethod
+    def discriminator_loss(scores_real, scores_fake):
+        loss = 0
+        real_losses = []
+        fake_losses = []
+        for dr, dg in zip(scores_real, scores_fake):
+            dr = dr.float()
+            dg = dg.float()
+            real_loss = torch.mean((1 - dr) ** 2)
+            fake_loss = torch.mean(dg ** 2)
+            loss += real_loss + fake_loss
+            real_losses.append(real_loss.item())
+            fake_losses.append(fake_loss.item())
+        return loss, real_losses, fake_losses
+
+    def forward(self, scores_disc_real, scores_disc_fake):
+        loss = 0.0
+        return_dict = {}
+        loss_disc, _, _ = self.discriminator_loss(scores_disc_real, scores_disc_fake)
+        loss_disc = loss_disc * self.disc_loss_alpha
+        loss = loss + loss_disc
+        return_dict["loss_disc"] = loss_disc
+        return_dict["loss"] = loss
+        return return_dict
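A Monte-Carlo sanity check for `kl_loss` (a hypothetical sketch): when the posterior q equals the prior p, the masked KL estimate should average to roughly zero over many samples.

    import torch

    B, C, T = 4, 192, 1000
    m_p = torch.randn(B, C, T)
    logs_p = torch.randn(B, C, T) * 0.1
    logs_q = logs_p.clone()                               # q == p
    z_p = m_p + torch.randn(B, C, T) * torch.exp(logs_q)  # sample z_p ~ q
    z_mask = torch.ones(B, 1, T)
    print(VitsGeneratorLoss.kl_loss(z_p, logs_q, m_p, logs_p, z_mask))  # ~0
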
@@ -0,0 +1,77 @@
import torch
from torch import nn
from torch.nn.modules.conv import Conv1d

from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator


class DiscriminatorS(torch.nn.Module):
    """HiFiGAN Scale Discriminator. Channel sizes are different from the original HiFiGAN.

    Args:
        use_spectral_norm (bool): if `True`, switch to spectral norm instead of weight norm.
    """

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """
        Args:
            x (Tensor): input waveform.

        Returns:
            Tensor: discriminator scores.
            List[Tensor]: list of features from the convolutional layers.
        """
        feat = []
        for l in self.convs:
            x = l(x)
            x = torch.nn.functional.leaky_relu(x, 0.1)
            feat.append(x)
        x = self.conv_post(x)
        feat.append(x)
        x = torch.flatten(x, 1, -1)
        return x, feat


class VitsDiscriminator(nn.Module):
    """VITS discriminator wrapping one Scale Discriminator and a stack of Period Discriminators.

    ::
        waveform -> ScaleDiscriminator() -> scores_sd, feats_sd --> append() -> scores, feats
               |--> MultiPeriodDiscriminator() -> scores_mpd, feats_mpd ----------^

    Args:
        use_spectral_norm (bool): if `True`, switch to spectral norm instead of weight norm.
    """

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        self.sd = DiscriminatorS(use_spectral_norm=use_spectral_norm)
        self.mpd = MultiPeriodDiscriminator(use_spectral_norm=use_spectral_norm)

    def forward(self, x):
        """
        Args:
            x (Tensor): input waveform.

        Returns:
            List[Tensor]: discriminator scores.
            List[List[Tensor]]: list of list of features from each layer of each discriminator.
        """
        scores, feats = self.mpd(x)
        score_sd, feats_sd = self.sd(x)
        return scores + [score_sd], feats + [feats_sd]
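A hypothetical usage sketch for the combined discriminator (assuming the HiFiGAN `MultiPeriodDiscriminator` import above resolves and returns score/feature lists):

    import torch

    disc = VitsDiscriminator()
    wav = torch.randn(2, 1, 8192)  # [B, 1, T] waveform segment
    scores, feats = disc(wav)
    # one score/feature list per period discriminator, plus the scale discriminator
    print(len(scores), len(feats))
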
@@ -0,0 +1,271 @@
import math

import torch
from torch import nn

from TTS.tts.layers.glow_tts.glow import WN
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.utils.data import sequence_mask

LRELU_SLOPE = 0.1


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


class TextEncoder(nn.Module):
    def __init__(
        self,
        n_vocab: int,
        out_channels: int,
        hidden_channels: int,
        hidden_channels_ffn: int,
        num_heads: int,
        num_layers: int,
        kernel_size: int,
        dropout_p: float,
    ):
        """Text Encoder for VITS model.

        Args:
            n_vocab (int): Number of characters for the embedding layer.
            out_channels (int): Number of channels for the output.
            hidden_channels (int): Number of channels for the hidden layers.
            hidden_channels_ffn (int): Number of channels for the convolutional layers.
            num_heads (int): Number of attention heads for the Transformer layers.
            num_layers (int): Number of Transformer layers.
            kernel_size (int): Kernel size for the FFN layers in Transformer network.
            dropout_p (float): Dropout rate for the Transformer layers.
        """
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels

        self.emb = nn.Embedding(n_vocab, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)

        self.encoder = RelativePositionTransformer(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            hidden_channels=hidden_channels,
            hidden_channels_ffn=hidden_channels_ffn,
            num_heads=num_heads,
            num_layers=num_layers,
            kernel_size=kernel_size,
            dropout_p=dropout_p,
            layer_norm_type="2",
            rel_attn_window_size=4,
        )

        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths):
        """
        Shapes:
            - x: :math:`[B, T]`
            - x_lengths: :math:`[B]`
        """
        x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

        x = self.encoder(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return x, m, logs, x_mask

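A hypothetical shape check for the text encoder above:

    import torch

    enc = TextEncoder(
        n_vocab=100, out_channels=192, hidden_channels=192,
        hidden_channels_ffn=768, num_heads=2, num_layers=6,
        kernel_size=3, dropout_p=0.1,
    )
    tokens = torch.randint(0, 100, (2, 50))  # [B, T]
    lengths = torch.tensor([50, 30])
    x, m, logs, x_mask = enc(tokens, lengths)
    print(x.shape, m.shape, x_mask.shape)    # [2, 192, 50] [2, 192, 50] [2, 1, 50]
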
class ResidualCouplingBlock(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        num_layers,
        dropout_p=0,
        cond_channels=0,
        mean_only=False,
    ):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.half_channels = channels // 2
        self.mean_only = mean_only
        # input layer
        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        # coupling layers
        self.enc = WN(
            hidden_channels,
            hidden_channels,
            kernel_size,
            dilation_rate,
            num_layers,
            dropout_p=dropout_p,
            c_in_channels=cond_channels,
        )
        # output layer
        # Initializing last layer to 0 makes the affine coupling layers
        # do nothing at first. This helps with training stability
        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """
        Shapes:
            - x: :math:`[B, C, T]`
            - x_mask: :math:`[B, 1, T]`
            - g: :math:`[B, C, 1]`
        """
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0) * x_mask
        h = self.enc(h, x_mask, g=g)
        stats = self.post(h) * x_mask
        if not self.mean_only:
            m, log_scale = torch.split(stats, [self.half_channels] * 2, 1)
        else:
            m = stats
            log_scale = torch.zeros_like(m)

        if not reverse:
            x1 = m + x1 * torch.exp(log_scale) * x_mask
            x = torch.cat([x0, x1], 1)
            logdet = torch.sum(log_scale, [1, 2])
            return x, logdet
        else:
            x1 = (x1 - m) * torch.exp(-log_scale) * x_mask
            x = torch.cat([x0, x1], 1)
            return x

class ResidualCouplingBlocks(nn.Module):
    def __init__(
        self,
        channels: int,
        hidden_channels: int,
        kernel_size: int,
        dilation_rate: int,
        num_layers: int,
        num_flows=4,
        cond_channels=0,
    ):
        """Residual Coupling blocks for VITS flow layers.

        Args:
            channels (int): Number of input and output tensor channels.
            hidden_channels (int): Number of hidden network channels.
            kernel_size (int): Kernel size of the WaveNet layers.
            dilation_rate (int): Dilation rate of the WaveNet layers.
            num_layers (int): Number of the WaveNet layers.
            num_flows (int, optional): Number of Residual Coupling blocks. Defaults to 4.
            cond_channels (int, optional): Number of channels of the conditioning tensor. Defaults to 0.
        """
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.num_layers = num_layers
        self.num_flows = num_flows
        self.cond_channels = cond_channels

        self.flows = nn.ModuleList()
        for _ in range(num_flows):
            self.flows.append(
                ResidualCouplingBlock(
                    channels,
                    hidden_channels,
                    kernel_size,
                    dilation_rate,
                    num_layers,
                    cond_channels=cond_channels,
                    mean_only=True,
                )
            )

    def forward(self, x, x_mask, g=None, reverse=False):
        """
        Shapes:
            - x: :math:`[B, C, T]`
            - x_mask: :math:`[B, 1, T]`
            - g: :math:`[B, C, 1]`
        """
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
                x = torch.flip(x, [1])
        else:
            for flow in reversed(self.flows):
                x = torch.flip(x, [1])
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x

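A hypothetical invertibility check: running the coupling stack forward and then in reverse should recover the input up to numerical error, since each coupling layer is an exact inverse of itself.

    import torch

    flow = ResidualCouplingBlocks(channels=192, hidden_channels=192, kernel_size=5, dilation_rate=1, num_layers=4)
    x = torch.randn(2, 192, 30)
    x_mask = torch.ones(2, 1, 30)
    z = flow(x, x_mask)                          # forward direction
    x_hat = flow(z, x_mask, reverse=True)        # inverse direction
    print(torch.allclose(x, x_hat, atol=1e-4))   # True
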
class PosteriorEncoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        kernel_size: int,
        dilation_rate: int,
        num_layers: int,
        cond_channels=0,
    ):
        """Posterior Encoder of VITS model.

        ::
            x -> conv1x1() -> WaveNet() (non-causal) -> conv1x1() -> split() -> [m, s] -> sample(m, s) -> z

        Args:
            in_channels (int): Number of input tensor channels.
            out_channels (int): Number of output tensor channels.
            hidden_channels (int): Number of hidden channels.
            kernel_size (int): Kernel size of the WaveNet convolution layers.
            dilation_rate (int): Dilation rate of the WaveNet layers.
            num_layers (int): Number of the WaveNet layers.
            cond_channels (int, optional): Number of conditioning tensor channels. Defaults to 0.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.num_layers = num_layers
        self.cond_channels = cond_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = WN(
            hidden_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels=cond_channels
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        """
        Shapes:
            - x: :math:`[B, C, T]`
            - x_lengths: :math:`[B, 1]`
            - g: :math:`[B, C, 1]`
        """
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        mean, log_scale = torch.split(stats, self.out_channels, dim=1)
        z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask
        return z, mean, log_scale, x_mask

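A hypothetical usage sketch: the posterior encoder maps a linear spectrogram to a latent sample together with its Gaussian statistics.

    import torch

    enc = PosteriorEncoder(
        in_channels=513, out_channels=192, hidden_channels=192,
        kernel_size=5, dilation_rate=1, num_layers=16,
    )
    spec = torch.randn(2, 513, 40)   # [B, C_spec, T]
    spec_lens = torch.tensor([40, 25])
    z, mean, log_scale, x_mask = enc(spec, spec_lens)
    print(z.shape)                   # torch.Size([2, 192, 40])
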
@@ -0,0 +1,276 @@
import math

import torch
from torch import nn
from torch.nn import functional as F

from TTS.tts.layers.generic.normalization import LayerNorm2
from TTS.tts.layers.vits.transforms import piecewise_rational_quadratic_transform


class DilatedDepthSeparableConv(nn.Module):
    def __init__(self, channels, kernel_size, num_layers, dropout_p=0.0):
        """Dilated Depth-wise Separable Convolution module.

        ::
            x |-> DDSConv(x) -> LayerNorm(x) -> GeLU(x) -> Conv1x1(x) -> LayerNorm(x) -> GeLU(x) -> + -> o
              |-------------------------------------------------------------------------------------^

        Args:
            channels (int): Number of input and output tensor channels.
            kernel_size (int): Kernel size of the depth-wise convolution layers.
            num_layers (int): Number of stacked layers.
            dropout_p (float, optional): Dropout rate. Defaults to 0.0.
        """
        super().__init__()
        self.num_layers = num_layers

        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for i in range(num_layers):
            dilation = kernel_size ** i
            padding = (kernel_size * dilation - dilation) // 2
            self.convs_sep.append(
                nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
            )
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm2(channels))
            self.norms_2.append(LayerNorm2(channels))
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x, x_mask, g=None):
        """
        Shapes:
            - x: :math:`[B, C, T]`
            - x_mask: :math:`[B, 1, T]`

        Returns:
            torch.tensor: Network output masked by the input sequence mask.
        """
        if g is not None:
            x = x + g
        for i in range(self.num_layers):
            y = self.convs_sep[i](x * x_mask)
            y = self.norms_1[i](y)
            y = F.gelu(y)
            y = self.convs_1x1[i](y)
            y = self.norms_2[i](y)
            y = F.gelu(y)
            y = self.dropout(y)
            x = x + y
        return x * x_mask

class ElementwiseAffine(nn.Module):
    """Element-wise affine transform; a BatchNorm alternative that does not use population statistics.

    Args:
        channels (int): Number of input tensor channels.
    """

    def __init__(self, channels):
        super().__init__()
        self.translation = nn.Parameter(torch.zeros(channels, 1))
        self.log_scale = nn.Parameter(torch.zeros(channels, 1))

    def forward(self, x, x_mask, reverse=False, **kwargs):  # pylint: disable=unused-argument
        if not reverse:
            y = (x * torch.exp(self.log_scale) + self.translation) * x_mask
            logdet = torch.sum(self.log_scale * x_mask, [1, 2])
            return y, logdet
        x = (x - self.translation) * torch.exp(-self.log_scale) * x_mask
        return x

class ConvFlow(nn.Module):
    """Dilated depth-wise separable convolution based spline flow.

    Args:
        in_channels (int): Number of input tensor channels.
        hidden_channels (int): Number of in-network channels.
        kernel_size (int): Convolutional kernel size.
        num_layers (int): Number of convolutional layers.
        num_bins (int, optional): Number of spline bins. Defaults to 10.
        tail_bound (float, optional): Tail bound for PRQT. Defaults to 5.0.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        kernel_size: int,
        num_layers: int,
        num_bins=10,
        tail_bound=5.0,
    ):
        super().__init__()
        self.num_bins = num_bins
        self.tail_bound = tail_bound
        self.hidden_channels = hidden_channels
        self.half_channels = in_channels // 2

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers, dropout_p=0.0)
        self.proj = nn.Conv1d(hidden_channels, self.half_channels * (num_bins * 3 - 1), 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0)
        h = self.convs(h, x_mask, g=g)
        h = self.proj(h) * x_mask

        b, c, t = x0.shape
        h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, cx?, t] -> [b, c, t, ?]

        unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.hidden_channels)
        unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.hidden_channels)
        unnormalized_derivatives = h[..., 2 * self.num_bins :]

        x1, logabsdet = piecewise_rational_quadratic_transform(
            x1,
            unnormalized_widths,
            unnormalized_heights,
            unnormalized_derivatives,
            inverse=reverse,
            tails="linear",
            tail_bound=self.tail_bound,
        )

        x = torch.cat([x0, x1], 1) * x_mask
        logdet = torch.sum(logabsdet * x_mask, [1, 2])
        if not reverse:
            return x, logdet
        return x

class StochasticDurationPredictor(nn.Module):
    """Stochastic duration predictor with Spline Flows.

    It applies Variational Dequantization and Variational Data Augmentation.

    Paper:
        SDP: https://arxiv.org/pdf/2106.06103.pdf
        Spline Flow: https://arxiv.org/abs/1906.04032

    ::
        ## Inference

        x -> TextCondEncoder() -> Flow() -> dr_hat
        noise ----------------------^

        ## Training
                                                                      |---------------------|
        x -> TextCondEncoder() -> + -> PosteriorEncoder() -> split() -> z_u, z_v -> (d - z_u) -> concat() -> Flow() -> noise
        d -> DurCondEncoder()  -> ^                                                               |
        |------------------------------------------------------------------------------------------|

    Args:
        in_channels (int): Number of input tensor channels.
        hidden_channels (int): Number of hidden channels.
        kernel_size (int): Kernel size of convolutional layers.
        dropout_p (float): Dropout rate.
        num_flows (int, optional): Number of flow blocks. Defaults to 4.
        cond_channels (int, optional): Number of channels of conditioning tensor. Defaults to 0.
    """

    def __init__(
        self, in_channels: int, hidden_channels: int, kernel_size: int, dropout_p: float, num_flows=4, cond_channels=0
    ):
        super().__init__()

        # condition encoder text
        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p)
        self.proj = nn.Conv1d(hidden_channels, hidden_channels, 1)

        # posterior encoder
        self.flows = nn.ModuleList()
        self.flows.append(ElementwiseAffine(2))
        self.flows += [ConvFlow(2, hidden_channels, kernel_size, num_layers=3) for _ in range(num_flows)]

        # condition encoder duration
        self.post_pre = nn.Conv1d(1, hidden_channels, 1)
        self.post_convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p)
        self.post_proj = nn.Conv1d(hidden_channels, hidden_channels, 1)

        # flow layers
        self.post_flows = nn.ModuleList()
        self.post_flows.append(ElementwiseAffine(2))
        self.post_flows += [ConvFlow(2, hidden_channels, kernel_size, num_layers=3) for _ in range(num_flows)]

        if cond_channels != 0 and cond_channels is not None:
            self.cond = nn.Conv1d(cond_channels, hidden_channels, 1)

    def forward(self, x, x_mask, dr=None, g=None, reverse=False, noise_scale=1.0):
        """
        Shapes:
            - x: :math:`[B, C, T]`
            - x_mask: :math:`[B, 1, T]`
            - dr: :math:`[B, 1, T]`
            - g: :math:`[B, C]`
        """
        # condition encoder text
        x = self.pre(x)
        if g is not None:
            x = x + self.cond(g)
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        if not reverse:
            flows = self.flows
            assert dr is not None

            # condition encoder duration
            h = self.post_pre(dr)
            h = self.post_convs(h, x_mask)
            h = self.post_proj(h) * x_mask
            noise = torch.rand(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
            z_q = noise

            # posterior encoder
            logdet_tot_q = 0.0
            for idx, flow in enumerate(self.post_flows):
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h))
                logdet_tot_q = logdet_tot_q + logdet_q
                if idx > 0:
                    z_q = torch.flip(z_q, [1])

            z_u, z_v = torch.split(z_q, [1, 1], 1)
            u = torch.sigmoid(z_u) * x_mask
            z0 = (dr - u) * x_mask

            # posterior encoder - neg log likelihood
            logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
            nll_posterior_encoder = (
                torch.sum(-0.5 * (math.log(2 * math.pi) + (noise ** 2)) * x_mask, [1, 2]) - logdet_tot_q
            )

            z0 = torch.log(torch.clamp_min(z0, 1e-5)) * x_mask
            logdet_tot = torch.sum(-z0, [1, 2])
            z = torch.cat([z0, z_v], 1)

            # flow layers
            for idx, flow in enumerate(flows):
                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
                logdet_tot = logdet_tot + logdet
                if idx > 0:
                    z = torch.flip(z, [1])

            # flow layers - neg log likelihood
            nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot
            return nll_flow_layers + nll_posterior_encoder

        flows = list(reversed(self.flows))
        flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
        z = torch.rand(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
        for flow in flows:
            z = torch.flip(z, [1])
            z = flow(z, x_mask, g=x, reverse=reverse)

        z0, _ = torch.split(z, [1, 1], 1)
        logw = z0
        return logw
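A hypothetical sketch of the predictor's two modes: a duration negative log-likelihood in training, sampled log-durations in inference.

    import torch

    sdp = StochasticDurationPredictor(in_channels=192, hidden_channels=192, kernel_size=3, dropout_p=0.5)
    x = torch.randn(2, 192, 50)    # text-encoder outputs [B, C, T]
    x_mask = torch.ones(2, 1, 50)

    # training: ground-truth durations in, per-sample NLL out
    dr = torch.randint(1, 10, (2, 1, 50)).float()
    nll = sdp(x, x_mask, dr=dr)    # shape [B]

    # inference: sample log-durations from noise
    logw = sdp(x, x_mask, reverse=True, noise_scale=0.8)  # [B, 1, T]
    durations = torch.exp(logw) * x_mask
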
@@ -0,0 +1,203 @@
# adopted from https://github.com/bayesiains/nflows

import numpy as np
import torch
from torch.nn import functional as F

DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3


def piecewise_rational_quadratic_transform(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails=None,
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    if tails is None:
        spline_fn = rational_quadratic_spline
        spline_kwargs = {}
    else:
        spline_fn = unconstrained_rational_quadratic_spline
        spline_kwargs = {"tails": tails, "tail_bound": tail_bound}

    outputs, logabsdet = spline_fn(
        inputs=inputs,
        unnormalized_widths=unnormalized_widths,
        unnormalized_heights=unnormalized_heights,
        unnormalized_derivatives=unnormalized_derivatives,
        inverse=inverse,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
        **spline_kwargs,
    )
    return outputs, logabsdet


def searchsorted(bin_locations, inputs, eps=1e-6):
    bin_locations[..., -1] += eps
    return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1


def unconstrained_rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails="linear",
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)

    if tails == "linear":
        unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
        constant = np.log(np.exp(1 - min_derivative) - 1)
        unnormalized_derivatives[..., 0] = constant
        unnormalized_derivatives[..., -1] = constant

        outputs[outside_interval_mask] = inputs[outside_interval_mask]
        logabsdet[outside_interval_mask] = 0
    else:
        raise RuntimeError("{} tails are not implemented.".format(tails))

    outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
        inputs=inputs[inside_interval_mask],
        unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
        unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
        unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
        inverse=inverse,
        left=-tail_bound,
        right=tail_bound,
        bottom=-tail_bound,
        top=tail_bound,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
    )

    return outputs, logabsdet


def rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    left=0.0,
    right=1.0,
    bottom=0.0,
    top=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    if torch.min(inputs) < left or torch.max(inputs) > right:
        raise ValueError("Input to a transform is not within its domain")

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError("Minimal bin width too large for the number of bins")
    if min_bin_height * num_bins > 1.0:
        raise ValueError("Minimal bin height too large for the number of bins")

    widths = F.softmax(unnormalized_widths, dim=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
    cumwidths = torch.cumsum(widths, dim=-1)
    cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
    cumwidths = (right - left) * cumwidths + left
    cumwidths[..., 0] = left
    cumwidths[..., -1] = right
    widths = cumwidths[..., 1:] - cumwidths[..., :-1]

    derivatives = min_derivative + F.softplus(unnormalized_derivatives)

    heights = F.softmax(unnormalized_heights, dim=-1)
    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
    cumheights = torch.cumsum(heights, dim=-1)
    cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
    cumheights = (top - bottom) * cumheights + bottom
    cumheights[..., 0] = bottom
    cumheights[..., -1] = top
    heights = cumheights[..., 1:] - cumheights[..., :-1]

    if inverse:
        bin_idx = searchsorted(cumheights, inputs)[..., None]
    else:
        bin_idx = searchsorted(cumwidths, inputs)[..., None]

    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]

    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
    delta = heights / widths
    input_delta = delta.gather(-1, bin_idx)[..., 0]

    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]

    input_heights = heights.gather(-1, bin_idx)[..., 0]

    if inverse:
        a = (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        ) + input_heights * (input_delta - input_derivatives)
        b = input_heights * input_derivatives - (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        )
        c = -input_delta * (inputs - input_cumheights)

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()

        root = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = root * input_bin_widths + input_cumwidths

        theta_one_minus_theta = root * (1 - root)
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
        )
        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * root.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - root).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        return outputs, -logabsdet
    else:
        theta = (inputs - input_cumwidths) / input_bin_widths
        theta_one_minus_theta = theta * (1 - theta)

        numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
        )
        outputs = input_cumheights + numerator / denominator

        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * theta.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - theta).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        return outputs, logabsdet
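A hypothetical round-trip check for the spline transform: applying the transform and then its inverse recovers the inputs, and the elementwise log-determinants cancel.

    import torch

    num_bins = 10
    x = torch.rand(2, 10) * 8 - 4          # values inside the tail bound
    uw = torch.randn(2, 10, num_bins)      # unnormalized widths
    uh = torch.randn(2, 10, num_bins)      # unnormalized heights
    ud = torch.randn(2, 10, num_bins - 1)  # unnormalized derivatives (padded internally for linear tails)
    y, logdet = piecewise_rational_quadratic_transform(x, uw, uh, ud, tails="linear", tail_bound=5.0)
    x_hat, inv_logdet = piecewise_rational_quadratic_transform(y, uw, uh, ud, inverse=True, tails="linear", tail_bound=5.0)
    print(torch.allclose(x, x_hat, atol=1e-4), torch.allclose(logdet, -inv_logdet, atol=1e-4))
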
@@ -212,13 +212,22 @@ class BaseTTS(BaseModel):
                 else None,
             )
 
-            if (
-                config.use_phonemes
-                and config.compute_input_seq_cache
-                and not os.path.exists(dataset.phoneme_cache_path)
-            ):
-                # precompute phonemes to have a better estimate of sequence lengths.
-                dataset.compute_input_seq(config.num_loader_workers)
+            if config.use_phonemes and config.compute_input_seq_cache:
+                if hasattr(self, "eval_data_items") and is_eval:
+                    dataset.items = self.eval_data_items
+                elif hasattr(self, "train_data_items") and not is_eval:
+                    dataset.items = self.train_data_items
+                else:
+                    # precompute phonemes to have a better estimate of sequence lengths.
+                    dataset.compute_input_seq(config.num_loader_workers)
+
+                    # TODO: find a more efficient solution
+                    # cheap hack - store items in the model state to avoid recomputing when reinitializing the dataset
+                    if is_eval:
+                        self.eval_data_items = dataset.items
+                    else:
+                        self.train_data_items = dataset.items
 
             dataset.sort_items()
 
             sampler = DistributedSampler(dataset) if num_gpus > 1 else None
@@ -0,0 +1,758 @@
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

import torch
from coqpit import Coqpit
from torch import nn
from torch.cuda.amp.autocast_mode import autocast

from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor

# from TTS.tts.layers.vits.sdp import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment
from TTS.utils.audio import AudioProcessor
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results


def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4):
    """Segment each sample in a batch based on the provided segment indices."""
    segments = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        index_start = segment_indices[i]
        index_end = index_start + segment_size
        segments[i] = x[i, :, index_start:index_end]
    return segments


def rand_segment(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4):
    """Create random segments based on the input lengths."""
    B, _, T = x.size()
    if x_lengths is None:
        x_lengths = T
    max_idxs = x_lengths - segment_size + 1
    assert all(max_idxs > 0), " [!] At least one sample is shorter than the segment size."
    ids_str = (torch.rand([B]).type_as(x) * max_idxs).long()
    ret = segment(x, ids_str, segment_size)
    return ret, ids_str

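A hypothetical sketch of how the two helpers cooperate during training: pick random latent-frame windows, then cut the matching waveform windows by scaling the indices with the hop length (256 below is an assumed value).

    import torch

    z = torch.randn(2, 192, 100)   # latent frames [B, C, T]
    z_lens = torch.tensor([100, 60])
    z_slice, ids = rand_segment(z, z_lens, segment_size=32)
    print(z_slice.shape)           # torch.Size([2, 192, 32])

    hop_length = 256               # assumed audio hop length
    wav = torch.randn(2, 1, 100 * hop_length)
    wav_slice = segment(wav, ids * hop_length, segment_size=32 * hop_length)
    print(wav_slice.shape)         # torch.Size([2, 1, 8192])
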
@dataclass
class VitsArgs(Coqpit):
    """VITS model arguments.

    Args:
        num_chars (int):
            Number of characters in the vocabulary. Defaults to 100.

        out_channels (int):
            Number of output channels. Defaults to 513.

        spec_segment_size (int):
            Decoder input segment size. Defaults to 32 `(32 * hop_length = waveform length)`.

        hidden_channels (int):
            Number of hidden channels of the model. Defaults to 192.

        hidden_channels_ffn_text_encoder (int):
            Number of hidden channels of the feed-forward layers of the text encoder transformer. Defaults to 768.

        num_heads_text_encoder (int):
            Number of attention heads of the text encoder transformer. Defaults to 2.

        num_layers_text_encoder (int):
            Number of transformer layers in the text encoder. Defaults to 6.

        kernel_size_text_encoder (int):
            Kernel size of the text encoder transformer FFN layers. Defaults to 3.

        dropout_p_text_encoder (float):
            Dropout rate of the text encoder. Defaults to 0.1.

        kernel_size_posterior_encoder (int):
            Kernel size of the posterior encoder's WaveNet layers. Defaults to 5.

        dilation_rate_posterior_encoder (int):
            Dilation rate of the posterior encoder's WaveNet layers. Defaults to 1.

        num_layers_posterior_encoder (int):
            Number of posterior encoder's WaveNet layers. Defaults to 16.

        kernel_size_flow (int):
            Kernel size of the Residual Coupling layers of the flow network. Defaults to 5.

        dilation_rate_flow (int):
            Dilation rate of the Residual Coupling WaveNet layers of the flow network. Defaults to 1.

        num_layers_flow (int):
            Number of Residual Coupling WaveNet layers of the flow network. Defaults to 4.

        resblock_type_decoder (str):
            Type of the residual block in the decoder network. Defaults to "1".

        resblock_kernel_sizes_decoder (List[int]):
            Kernel sizes of the residual blocks in the decoder network. Defaults to `[3, 7, 11]`.

        resblock_dilation_sizes_decoder (List[List[int]]):
            Dilation sizes of the residual blocks in the decoder network. Defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`.

        upsample_rates_decoder (List[int]):
            Upsampling rates for each consecutive upsampling layer in the decoder network. The product of these
            values must equal the hop length used for computing spectrograms. Defaults to `[8, 8, 2, 2]`.

        upsample_initial_channel_decoder (int):
            Number of hidden channels of the first upsampling convolution layer of the decoder network. Defaults to 512.

        upsample_kernel_sizes_decoder (List[int]):
            Kernel sizes for each upsampling layer of the decoder network. Defaults to `[16, 16, 4, 4]`.

        use_sdp (bool):
            Use the Stochastic Duration Predictor. Defaults to True.

        noise_scale (float):
            Noise scale used for the sample noise tensor in training. Defaults to 1.0.

        inference_noise_scale (float):
            Noise scale used for the sample noise tensor in inference. Defaults to 0.667.

        length_scale (int):
            Scale factor for the predicted duration values. Smaller values result in faster speech. Defaults to 1.

        noise_scale_dp (float):
            Noise scale used by the Stochastic Duration Predictor sample noise in training. Defaults to 1.0.

        inference_noise_scale_dp (float):
            Noise scale for the Stochastic Duration Predictor in inference. Defaults to 0.8.

        max_inference_len (int):
            Maximum inference length to limit the memory use. Defaults to None.

        init_discriminator (bool):
            Initialize the discriminator network if set True. Set False for inference. Defaults to True.

        use_spectral_norm_disriminator (bool):
            Use spectral normalization over weight norm in the discriminator. Defaults to False.

        use_speaker_embedding (bool):
            Enable/Disable speaker embedding for multi-speaker models. Defaults to False.

        num_speakers (int):
            Number of speakers for the speaker embedding layer. Defaults to 0.

        speakers_file (str):
            Path to the speaker mapping file for the Speaker Manager. Defaults to None.

        speaker_embedding_channels (int):
            Number of speaker embedding channels. Defaults to 256.

        use_d_vector_file (bool):
            Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False.

        d_vector_dim (int):
            Number of d-vector channels. Defaults to 0.

        detach_dp_input (bool):
            Detach the duration predictor's input from the network to stop the gradients. Defaults to True.
    """

    num_chars: int = 100
    out_channels: int = 513
    spec_segment_size: int = 32
    hidden_channels: int = 192
    hidden_channels_ffn_text_encoder: int = 768
    num_heads_text_encoder: int = 2
    num_layers_text_encoder: int = 6
    kernel_size_text_encoder: int = 3
    dropout_p_text_encoder: float = 0.1
    kernel_size_posterior_encoder: int = 5
    dilation_rate_posterior_encoder: int = 1
    num_layers_posterior_encoder: int = 16
    kernel_size_flow: int = 5
    dilation_rate_flow: int = 1
    num_layers_flow: int = 4
    resblock_type_decoder: str = "1"
    resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11])
    resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
    upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2])
    upsample_initial_channel_decoder: int = 512
    upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
    use_sdp: bool = True
    noise_scale: float = 1.0
    inference_noise_scale: float = 0.667
    length_scale: int = 1
    noise_scale_dp: float = 1.0
    inference_noise_scale_dp: float = 0.8
    max_inference_len: int = None
    init_discriminator: bool = True
    use_spectral_norm_disriminator: bool = False
    use_speaker_embedding: bool = False
    num_speakers: int = 0
    speakers_file: str = None
    speaker_embedding_channels: int = 256
    use_d_vector_file: bool = False
    d_vector_dim: int = 0
    detach_dp_input: bool = True
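# Example (sketch): overriding a few of the defaults above.
#   >>> args = VitsArgs(num_chars=120, use_sdp=False, use_speaker_embedding=True, num_speakers=4)
#   >>> args.hidden_channels
#   192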
class Vits(BaseTTS):
    """VITS TTS model

    Paper::
        https://arxiv.org/pdf/2106.06103.pdf

    Paper Abstract::
        Several recent end-to-end text-to-speech (TTS) models enabling single-stage training and parallel
        sampling have been proposed, but their sample quality does not match that of two-stage TTS systems.
        In this work, we present a parallel end-to-end TTS method that generates more natural sounding audio than
        current two-stage models. Our method adopts variational inference augmented with normalizing flows and
        an adversarial training process, which improves the expressive power of generative modeling. We also propose a
        stochastic duration predictor to synthesize speech with diverse rhythms from input text. With the
        uncertainty modeling over latent variables and the stochastic duration predictor, our method expresses the
        natural one-to-many relationship in which a text input can be spoken in multiple ways
        with different pitches and rhythms. A subjective human evaluation (mean opinion score, or MOS)
        on the LJ Speech, a single speaker dataset, shows that our method outperforms the best publicly
        available TTS systems and achieves a MOS comparable to ground truth.

    Check :class:`TTS.tts.configs.vits_config.VitsConfig` for class arguments.

    Examples:
        >>> from TTS.tts.configs import VitsConfig
        >>> from TTS.tts.models.vits import Vits
        >>> config = VitsConfig()
        >>> model = Vits(config)
    """

    # pylint: disable=dangerous-default-value

    def __init__(self, config: Coqpit):

        super().__init__()

        self.END2END = True

        if config.__class__.__name__ == "VitsConfig":
            # loading from VitsConfig
            if "num_chars" not in config:
                _, self.config, num_chars = self.get_characters(config)
                config.model_args.num_chars = num_chars
            else:
                self.config = config
                config.model_args.num_chars = config.num_chars
            args = self.config.model_args
        elif isinstance(config, VitsArgs):
            # loading from VitsArgs
            self.config = config
            args = config
        else:
            raise ValueError("config must be either a VitsConfig or VitsArgs")

        self.args = args

        self.init_multispeaker(config)

        self.length_scale = args.length_scale
        self.noise_scale = args.noise_scale
        self.inference_noise_scale = args.inference_noise_scale
        self.inference_noise_scale_dp = args.inference_noise_scale_dp
        self.noise_scale_dp = args.noise_scale_dp
        self.max_inference_len = args.max_inference_len
        self.spec_segment_size = args.spec_segment_size

        self.text_encoder = TextEncoder(
            args.num_chars,
            args.hidden_channels,
            args.hidden_channels,
            args.hidden_channels_ffn_text_encoder,
            args.num_heads_text_encoder,
            args.num_layers_text_encoder,
            args.kernel_size_text_encoder,
            args.dropout_p_text_encoder,
        )

        self.posterior_encoder = PosteriorEncoder(
            args.out_channels,
            args.hidden_channels,
            args.hidden_channels,
            kernel_size=args.kernel_size_posterior_encoder,
            dilation_rate=args.dilation_rate_posterior_encoder,
            num_layers=args.num_layers_posterior_encoder,
            cond_channels=self.embedded_speaker_dim,
        )

        self.flow = ResidualCouplingBlocks(
            args.hidden_channels,
            args.hidden_channels,
            kernel_size=args.kernel_size_flow,
            dilation_rate=args.dilation_rate_flow,
            num_layers=args.num_layers_flow,
            cond_channels=self.embedded_speaker_dim,
        )

        if args.use_sdp:
            self.duration_predictor = StochasticDurationPredictor(
                args.hidden_channels, 192, 3, 0.5, 4, cond_channels=self.embedded_speaker_dim
            )
        else:
            self.duration_predictor = DurationPredictor(
                args.hidden_channels, 256, 3, 0.5, cond_channels=self.embedded_speaker_dim
            )

        self.waveform_decoder = HifiganGenerator(
            args.hidden_channels,
            1,
            args.resblock_type_decoder,
            args.resblock_dilation_sizes_decoder,
            args.resblock_kernel_sizes_decoder,
            args.upsample_kernel_sizes_decoder,
            args.upsample_initial_channel_decoder,
            args.upsample_rates_decoder,
            inference_padding=0,
            cond_channels=self.embedded_speaker_dim,
            conv_pre_weight_norm=False,
            conv_post_weight_norm=False,
            conv_post_bias=False,
        )

        if args.init_discriminator:
            self.disc = VitsDiscriminator(use_spectral_norm=args.use_spectral_norm_disriminator)

    def init_multispeaker(self, config: Coqpit, data: List = None):
        """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
        or with external `d_vectors` computed from a speaker encoder model.

        If you need a different behaviour, override this function for your model.

        Args:
            config (Coqpit): Model configuration.
            data (List, optional): Dataset items to infer the number of speakers. Defaults to None.
        """
        if hasattr(config, "model_args"):
            config = config.model_args
        self.embedded_speaker_dim = 0
        # init speaker manager
        self.speaker_manager = get_speaker_manager(config, data=data)
        if config.num_speakers > 0 and self.speaker_manager.num_speakers == 0:
            self.speaker_manager.num_speakers = config.num_speakers
        self.num_speakers = self.speaker_manager.num_speakers
        # init speaker embedding layer
        if config.use_speaker_embedding and not config.use_d_vector_file:
            self.embedded_speaker_dim = config.speaker_embedding_channels
            self.emb_g = nn.Embedding(config.num_speakers, config.speaker_embedding_channels)
        # init d-vector usage
        if config.use_d_vector_file:
            self.embedded_speaker_dim = config.d_vector_dim

    @staticmethod
    def _set_cond_input(aux_input: Dict):
        """Set the speaker conditioning input based on the multi-speaker mode."""
        sid, g = None, None
        if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
            sid = aux_input["speaker_ids"]
            if sid.ndim == 0:
                sid = sid.unsqueeze_(0)
        if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
            g = aux_input["d_vectors"]
        return sid, g
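    # Example (sketch) of the conditioning inputs `_set_cond_input` normalizes:
    #   >>> sid, g = Vits._set_cond_input({"speaker_ids": torch.tensor(3), "d_vectors": None})
    #   >>> sid  # 0-dim IDs are promoted to a batch of one
    #   tensor([3])
    #   >>> g is None
    #   True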
    def forward(
        self,
        x: torch.tensor,
        x_lengths: torch.tensor,
        y: torch.tensor,
        y_lengths: torch.tensor,
        aux_input={"d_vectors": None, "speaker_ids": None},
    ) -> Dict:
        """Forward pass of the model.

        Args:
            x (torch.tensor): Batch of input character sequence IDs.
            x_lengths (torch.tensor): Batch of input character sequence lengths.
            y (torch.tensor): Batch of input spectrograms.
            y_lengths (torch.tensor): Batch of input spectrogram lengths.
            aux_input (dict, optional): Auxiliary inputs for multi-speaker training. Defaults to {"d_vectors": None, "speaker_ids": None}.

        Returns:
            Dict: model outputs keyed by the output name.

        Shapes:
            - x: :math:`[B, T_seq]`
            - x_lengths: :math:`[B]`
            - y: :math:`[B, C, T_spec]`
            - y_lengths: :math:`[B]`
            - d_vectors: :math:`[B, C, 1]`
            - speaker_ids: :math:`[B]`
        """
        outputs = {}
        sid, g = self._set_cond_input(aux_input)
        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)

        # speaker embedding
        if self.num_speakers > 1 and sid is not None:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]

        # posterior encoder
        z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)

        # flow layers
        z_p = self.flow(z, y_mask, g=g)

        # find the alignment path
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        with torch.no_grad():
            o_scale = torch.exp(-2 * logs_p)
            # logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)  # [b, t, 1]
            logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
            logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
            # logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
            logp = logp2 + logp3
            attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()

        # duration predictor
        attn_durations = attn.sum(3)
        if self.args.use_sdp:
            nll_duration = self.duration_predictor(
                x.detach() if self.args.detach_dp_input else x,
                x_mask,
                attn_durations,
                g=g.detach() if self.args.detach_dp_input and g is not None else g,
            )
            nll_duration = torch.sum(nll_duration.float() / torch.sum(x_mask))
            outputs["nll_duration"] = nll_duration
        else:
            attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
            log_durations = self.duration_predictor(
                x.detach() if self.args.detach_dp_input else x,
                x_mask,
                g=g.detach() if self.args.detach_dp_input and g is not None else g,
            )
            loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
            outputs["loss_duration"] = loss_duration

        # expand prior
        m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
        logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])

        # select a random feature segment for the waveform decoder
        z_slice, slice_ids = rand_segment(z, y_lengths, self.spec_segment_size)
        o = self.waveform_decoder(z_slice, g=g)
        outputs.update(
            {
                "model_outputs": o,
                "alignments": attn.squeeze(1),
                "slice_ids": slice_ids,
                "z": z,
                "z_p": z_p,
                "m_p": m_p,
                "logs_p": logs_p,
                "m_q": m_q,
                "logs_q": logs_q,
            }
        )
        return outputs
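    # Note on the alignment score in `forward` above: expanding the diagonal
    # Gaussian log-likelihood of z_p under the prior,
    #   log N(z; m, s) = -0.5 * log(2 * pi) - logs - 0.5 * (z - m)^2 * exp(-2 * logs),
    # gives four terms per channel; logp2 and logp3 are the two cross terms that
    # depend on z_p, while the remaining terms (logp1, logp4) are left commented
    # out in this version.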
    def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
        """
        Shapes:
            - x: :math:`[B, T_seq]`
            - d_vectors: :math:`[B, C, 1]`
            - speaker_ids: :math:`[B]`
        """
        sid, g = self._set_cond_input(aux_input)
        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)

        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)

        if self.num_speakers > 0 and sid is not None:
            g = self.emb_g(sid).unsqueeze(-1)

        if self.args.use_sdp:
            logw = self.duration_predictor(x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp)
        else:
            logw = self.duration_predictor(x, x_mask, g=g)

        w = torch.exp(logw) * x_mask * self.length_scale
        w_ceil = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))

        m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
        logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
        z = self.flow(z_p, y_mask, g=g, reverse=True)
        o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)

        outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p}
        return outputs
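    # Inference usage sketch (illustrative; the random token IDs stand in for a
    # tokenized sentence from the usual text-processing pipeline):
    #   >>> model = Vits(VitsConfig())
    #   >>> model.eval()
    #   >>> x = torch.randint(0, model.args.num_chars, (1, 50))  # [B, T_seq]
    #   >>> out = model.inference(x)
    #   >>> out["model_outputs"].shape  # [1, 1, T_wav]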
    def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
        """TODO: create an end-point for voice conversion"""
        assert self.num_speakers > 0, "num_speakers has to be larger than 0."
        g_src = self.emb_g(sid_src).unsqueeze(-1)
        g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
        z, _, _, y_mask = self.posterior_encoder(y, y_lengths, g=g_src)
        z_p = self.flow(z, y_mask, g=g_src)
        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
        o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
        return o_hat, y_mask, (z, z_p, z_hat)
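    # The conversion above works because the flow is invertible: the utterance is
    # encoded to z under the source-speaker condition, mapped into the (nearly)
    # speaker-independent prior space z_p, then mapped back with the target-speaker
    # condition and vocoded by the shared waveform decoder.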
    def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
        """Perform a single training step. Run the model forward pass and compute losses.

        Args:
            batch (Dict): Input tensors.
            criterion (nn.Module): Loss layer designed for the model.
            optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.

        Returns:
            Tuple[Dict, Dict]: Model outputs and computed losses.
        """
        # pylint: disable=attribute-defined-outside-init
        if optimizer_idx not in [0, 1]:
            raise ValueError(" [!] Unexpected `optimizer_idx`.")

        if optimizer_idx == 0:
            text_input = batch["text_input"]
            text_lengths = batch["text_lengths"]
            mel_lengths = batch["mel_lengths"]
            linear_input = batch["linear_input"]
            d_vectors = batch["d_vectors"]
            speaker_ids = batch["speaker_ids"]
            waveform = batch["waveform"]

            # generator pass
            outputs = self.forward(
                text_input,
                text_lengths,
                linear_input.transpose(1, 2),
                mel_lengths,
                aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
            )

            # cache tensors for the discriminator
            self.y_disc_cache = None
            self.wav_seg_disc_cache = None
            self.y_disc_cache = outputs["model_outputs"]
            wav_seg = segment(
                waveform.transpose(1, 2),
                outputs["slice_ids"] * self.config.audio.hop_length,
                self.args.spec_segment_size * self.config.audio.hop_length,
            )
            self.wav_seg_disc_cache = wav_seg
            outputs["waveform_seg"] = wav_seg

            # compute discriminator scores and features
            outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(outputs["model_outputs"])
            _, outputs["feats_disc_real"] = self.disc(wav_seg)

            # compute losses
            with autocast(enabled=False):  # use float32 for the criterion
                loss_dict = criterion[optimizer_idx](
                    waveform_hat=outputs["model_outputs"].float(),
                    waveform=wav_seg.float(),
                    z_p=outputs["z_p"].float(),
                    logs_q=outputs["logs_q"].float(),
                    m_p=outputs["m_p"].float(),
                    logs_p=outputs["logs_p"].float(),
                    z_len=mel_lengths,
                    scores_disc_fake=outputs["scores_disc_fake"],
                    feats_disc_fake=outputs["feats_disc_fake"],
                    feats_disc_real=outputs["feats_disc_real"],
                )

            # handle the duration loss
            if self.args.use_sdp:
                loss_dict["nll_duration"] = outputs["nll_duration"]
                loss_dict["loss"] += outputs["nll_duration"]
            else:
                loss_dict["loss_duration"] = outputs["loss_duration"]
                loss_dict["loss"] += outputs["loss_duration"]

        elif optimizer_idx == 1:
            # discriminator pass
            outputs = {}

            # compute scores and features
            outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(self.y_disc_cache.detach())
            outputs["scores_disc_real"], outputs["feats_disc_real"] = self.disc(self.wav_seg_disc_cache)

            # compute loss
            with autocast(enabled=False):  # use float32 for the criterion
                loss_dict = criterion[optimizer_idx](
                    outputs["scores_disc_real"],
                    outputs["scores_disc_fake"],
                )
        return outputs, loss_dict
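    # Two-pass optimization sketch (how a trainer might drive `train_step`;
    # the loop below is illustrative, not the actual Trainer API):
    #   >>> for batch in loader:
    #   ...     for idx in (0, 1):  # 0: generator pass, 1: discriminator pass
    #   ...         outputs, losses = model.train_step(batch, criterion, idx)
    #   ...         optimizers[idx].zero_grad()
    #   ...         losses["loss"].backward()
    #   ...         optimizers[idx].step()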
    def train_log(
        self, ap: AudioProcessor, batch: Dict, outputs: List, name_prefix="train"
    ):  # pylint: disable=no-self-use
        """Create visualizations and waveform examples.

        For example, here you can plot spectrograms and generate sample waveforms from these spectrograms to
        be projected onto Tensorboard.

        Args:
            ap (AudioProcessor): audio processor used at training.
            batch (Dict): Model inputs used at the previous training step.
            outputs (List): Model outputs generated at the previous training step.

        Returns:
            Tuple[Dict, np.ndarray]: training plots and output waveform.
        """
        y_hat = outputs[0]["model_outputs"]
        y = outputs[0]["waveform_seg"]
        figures = plot_results(y_hat, y, ap, name_prefix)
        sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
        audios = {f"{name_prefix}/audio": sample_voice}

        alignments = outputs[0]["alignments"]
        align_img = alignments[0].data.cpu().numpy().T

        figures.update(
            {
                "alignment": plot_alignment(align_img, output_fig=False),
            }
        )

        return figures, audios

    @torch.no_grad()
    def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
        return self.train_step(batch, criterion, optimizer_idx)

    def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
        return self.train_log(ap, batch, outputs, "eval")

    @torch.no_grad()
    def test_run(self, ap) -> Tuple[Dict, Dict]:
        """Generic test run for `tts` models used by `Trainer`.

        You can override this for a different behaviour.

        Returns:
            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
        """
        print(" | > Synthesizing test sentences.")
        test_audios = {}
        test_figures = {}
        test_sentences = self.config.test_sentences
        aux_inputs = self.get_aux_input()
        for idx, sen in enumerate(test_sentences):
            wav, alignment, _, _ = synthesis(
                self,
                sen,
                self.config,
                "cuda" in str(next(self.parameters()).device),
                ap,
                speaker_id=aux_inputs["speaker_id"],
                d_vector=aux_inputs["d_vector"],
                style_wav=aux_inputs["style_wav"],
                enable_eos_bos_chars=self.config.enable_eos_bos_chars,
                use_griffin_lim=True,
                do_trim_silence=False,
            ).values()

            test_audios["{}-audio".format(idx)] = wav
            test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
        return test_figures, test_audios

    def get_optimizer(self) -> List:
        """Initiate and return the GAN optimizers based on the config parameters.

        It returns two optimizers in a list. The first is for the generator and the second is for the discriminator.

        Returns:
            List: optimizers.
        """
        self.disc.requires_grad_(False)
        gen_parameters = filter(lambda p: p.requires_grad, self.parameters())
        self.disc.requires_grad_(True)
        optimizer1 = get_optimizer(
            self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
        )
        optimizer2 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
        return [optimizer1, optimizer2]
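    # The `requires_grad_` toggling in `get_optimizer` above is a parameter-
    # partitioning trick: with the discriminator temporarily frozen,
    # `filter(lambda p: p.requires_grad, self.parameters())` yields exactly the
    # generator-side parameters, so the two optimizers never share a parameter.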
    def get_lr(self) -> List:
        """Set the initial learning rates for each optimizer.

        Returns:
            List: learning rates for each optimizer.
        """
        return [self.config.lr_gen, self.config.lr_disc]

    def get_scheduler(self, optimizer) -> List:
        """Set the schedulers for each optimizer.

        Args:
            optimizer (List[`torch.optim.Optimizer`]): List of optimizers.

        Returns:
            List: Schedulers, one for each optimizer.
        """
        scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
        scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
        return [scheduler1, scheduler2]

    def get_criterion(self):
        """Get criterions for each optimizer. The index in the output list matches the optimizer idx used in
        `train_step()`."""
        from TTS.tts.layers.losses import (  # pylint: disable=import-outside-toplevel
            VitsDiscriminatorLoss,
            VitsGeneratorLoss,
        )

        return [VitsGeneratorLoss(self.config), VitsDiscriminatorLoss(self.config)]

    @staticmethod
    def make_symbols(config):
        """Create a custom arrangement of symbols used by the model. The output list of symbols propagates along the
        whole training and inference steps."""
        _pad = config.characters["pad"]
        _punctuations = config.characters["punctuations"]
        _letters = config.characters["characters"]
        _letters_ipa = config.characters["phonemes"]
        symbols = [_pad] + list(_punctuations) + list(_letters)
        if config.use_phonemes:
            symbols += list(_letters_ipa)
        return symbols

    @staticmethod
    def get_characters(config: Coqpit):
        if config.characters is not None:
            symbols = Vits.make_symbols(config)
        else:
            from TTS.tts.utils.text.symbols import (  # pylint: disable=import-outside-toplevel
                parse_symbols,
                phonemes,
                symbols,
            )

            config.characters = parse_symbols()
            if config.use_phonemes:
                symbols = phonemes
        num_chars = len(symbols) + getattr(config, "add_blank", False)
        return symbols, config, num_chars

    def load_checkpoint(
        self, config, checkpoint_path, eval=False
    ):  # pylint: disable=unused-argument, redefined-builtin
        """Load the model checkpoint and set up for training or inference."""
        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        self.load_state_dict(state["model"])
        if eval:
            self.eval()
            assert not self.training
@ -81,7 +81,6 @@ def text2phone(text, language, use_espeak_phonemes=False):
        # Fix a few phonemes
        ph = ph.translate(GRUUT_TRANS_TABLE)
        return ph

    raise ValueError(f" [!] Language {language} is not supported for phonemization.")
@ -116,6 +115,7 @@ def phoneme_to_sequence(
    use_espeak_phonemes: bool = False,
) -> List[int]:
    """Converts a string of phonemes to a sequence of IDs.

    If `custom_symbols` is provided, it will override the default symbols.

    Args:
        text (str): string to convert to a sequence
@ -132,12 +132,11 @@ def phoneme_to_sequence(
    # pylint: disable=global-statement
    global _phonemes_to_id, _phonemes

    if custom_symbols is not None:
        _phonemes = custom_symbols
    elif tp:
        _, _phonemes = make_symbols(**tp)
    _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}

    sequence = []
    clean_text = _clean_text(text, cleaner_names)
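# Net effect of this hunk: `custom_symbols` now takes precedence over the `tp`
# characters config when rebuilding the phoneme ID table, so a model-specific
# symbol arrangement (e.g. one produced by `Vits.make_symbols`) wins over the
# module defaults.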
@ -155,14 +154,17 @@ def phoneme_to_sequence(
    return sequence


def sequence_to_phoneme(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List[str] = None):
    # pylint: disable=global-statement
    """Converts a sequence of IDs back to a string"""
    global _id_to_phonemes, _phonemes
    if add_blank:
        sequence = list(filter(lambda x: x != len(_phonemes), sequence))
    result = ""

    if custom_symbols is not None:
        _phonemes = custom_symbols
    elif tp:
        _, _phonemes = make_symbols(**tp)
    _id_to_phonemes = {i: s for i, s in enumerate(_phonemes)}
@ -177,6 +179,7 @@ def text_to_sequence(
    text: str, cleaner_names: List[str], custom_symbols: List[str] = None, tp: Dict = None, add_blank: bool = False
) -> List[int]:
    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    If `custom_symbols` is provided, it will override the default symbols.

    Args:
        text (str): string to convert to a sequence
@ -189,12 +192,12 @@ def text_to_sequence(
    """
    # pylint: disable=global-statement
    global _symbol_to_id, _symbols

    if custom_symbols is not None:
        _symbols = custom_symbols
    elif tp:
        _symbols, _ = make_symbols(**tp)
    _symbol_to_id = {s: i for i, s in enumerate(_symbols)}

    sequence = []
@ -213,14 +216,16 @@ def text_to_sequence(
    return sequence


def sequence_to_text(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List[str] = None):
    """Converts a sequence of IDs back to a string"""
    # pylint: disable=global-statement
    global _id_to_symbol, _symbols
    if add_blank:
        sequence = list(filter(lambda x: x != len(_symbols), sequence))

    if custom_symbols is not None:
        _symbols = custom_symbols
    elif tp:
        _symbols, _ = make_symbols(**tp)
    _id_to_symbol = {i: s for i, s in enumerate(_symbols)}
@ -96,10 +96,12 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
        )
        self.mel_basis = torch.from_numpy(mel_basis).float()

    @staticmethod
    def _amp_to_db(x, spec_gain=1.0):
        return torch.log(torch.clamp(x, min=1e-5) * spec_gain)

    @staticmethod
    def _db_to_amp(x, spec_gain=1.0):
        return torch.exp(x) / spec_gain
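# Round-trip sketch for the two static helpers above (values are above the
# 1e-5 clamp floor, default spec_gain):
#   >>> x = torch.tensor([1e-3, 0.1, 1.0])
#   >>> torch.allclose(TorchSTFT._db_to_amp(TorchSTFT._amp_to_db(x)), x)
#   True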
@ -33,10 +33,10 @@ class DiscriminatorP(torch.nn.Module):
        norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
        self.convs = nn.ModuleList(
            [
                norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
                norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
                norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
                norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
                norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
            ]
        )
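# `get_padding(k, d)` is the usual "same"-padding helper, (k * d - d) // 2, so
# this hunk ties the padding to the actual `kernel_size` argument instead of a
# value hard-coded for kernel_size=5 (for which get_padding(5, 1) == 2).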
@ -81,15 +81,15 @@ class MultiPeriodDiscriminator(torch.nn.Module):
    Periods are suggested to be prime numbers to reduce the overlap between each discriminator.
    """

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorP(2, use_spectral_norm=use_spectral_norm),
                DiscriminatorP(3, use_spectral_norm=use_spectral_norm),
                DiscriminatorP(5, use_spectral_norm=use_spectral_norm),
                DiscriminatorP(7, use_spectral_norm=use_spectral_norm),
                DiscriminatorP(11, use_spectral_norm=use_spectral_norm),
            ]
        )
@ -170,6 +170,10 @@ class HifiganGenerator(torch.nn.Module):
        upsample_initial_channel,
        upsample_factors,
        inference_padding=5,
        cond_channels=0,
        conv_pre_weight_norm=True,
        conv_post_weight_norm=True,
        conv_post_bias=True,
    ):
        r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)
@ -218,12 +222,21 @@ class HifiganGenerator(torch.nn.Module):
        for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
            self.resblocks.append(resblock(ch, k, d))
        # post convolution layer
        self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
        if cond_channels > 0:
            self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)

        if not conv_pre_weight_norm:
            remove_weight_norm(self.conv_pre)

        if not conv_post_weight_norm:
            remove_weight_norm(self.conv_post)

    def forward(self, x, g=None):
        """
        Args:
            x (Tensor): feature input tensor.
            g (Tensor): global conditioning input tensor.

        Returns:
            Tensor: output waveform.

@ -233,6 +246,8 @@ class HifiganGenerator(torch.nn.Module):
            Tensor: [B, 1, T]
        """
        o = self.conv_pre(x)
        if hasattr(self, "cond_layer"):
            o = o + self.cond_layer(g)
        for i in range(self.num_upsamples):
            o = F.leaky_relu(o, LRELU_SLOPE)
            o = self.ups[i](o)
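# Global-conditioning sketch (illustrative arguments mirroring the VITS decoder
# defaults, where the product of the upsample rates equals the hop length 256):
#   >>> gen = HifiganGenerator(
#   ...     192, 1, "1", [[1, 3, 5]] * 3, [3, 7, 11], [16, 16, 4, 4], 512, [8, 8, 2, 2],
#   ...     cond_channels=256,
#   ... )
#   >>> z = torch.randn(1, 192, 32)  # latent frames
#   >>> g = torch.randn(1, 256, 1)   # global speaker embedding
#   >>> gen(z, g=g).shape
#   torch.Size([1, 1, 8192])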
@ -1,3 +1,5 @@
from typing import List

import numpy as np
import torch
import torch.nn.functional as F

@ -10,18 +12,35 @@ LRELU_SLOPE = 0.1
class UnivnetGenerator(torch.nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        cond_channels: int,
        upsample_factors: List[int],
        lvc_layers_each_block: int,
        lvc_kernel_size: int,
        kpnet_hidden_channels: int,
        kpnet_conv_size: int,
        dropout: float,
        use_weight_norm=True,
    ):
        """Univnet Generator network.

        Paper: https://arxiv.org/pdf/2106.07889.pdf

        Args:
            in_channels (int): Number of input tensor channels.
            out_channels (int): Number of channels of the output tensor.
            hidden_channels (int): Number of hidden network channels.
            cond_channels (int): Number of channels of the conditioning tensors.
            upsample_factors (List[int]): List of upsample factors for the upsampling layers.
            lvc_layers_each_block (int): Number of LVC layers in each block.
            lvc_kernel_size (int): Kernel size of the LVC layers.
            kpnet_hidden_channels (int): Number of hidden channels in the key-point network.
            kpnet_conv_size (int): Number of convolution channels in the key-point network.
            dropout (float): Dropout rate.
            use_weight_norm (bool, optional): Enable/disable weight norm. Defaults to True.
        """

        super().__init__()
        self.in_channels = in_channels
@ -1,8 +1,11 @@
from typing import Dict

import numpy as np
import torch
from matplotlib import pyplot as plt

from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor


def interpolate_vocoder_input(scale_factor, spec):

@ -26,12 +29,24 @@ def interpolate_vocoder_input(scale_factor, spec):
    return spec


def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_prefix: str = None) -> Dict:
    """Plot the predicted and the real waveform and their spectrograms.

    Args:
        y_hat (torch.tensor): Predicted waveform.
        y (torch.tensor): Real waveform.
        ap (AudioProcessor): Audio processor used to process the waveform.
        name_prefix (str, optional): Name prefix used to name the figures. Defaults to None.

    Returns:
        Dict: output figures keyed by the name of the figures.
    """
    if name_prefix is None:
        name_prefix = ""

    # select an instance from batch
    y_hat = y_hat[0].squeeze().detach().cpu().numpy()
    y = y[0].squeeze().detach().cpu().numpy()

    spec_fake = ap.melspectrogram(y_hat).T
    spec_real = ap.melspectrogram(y).T
@ -0,0 +1,54 @@
import glob
import os
import shutil

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import VitsConfig

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")


config = VitsConfig(
    batch_size=2,
    eval_batch_size=2,
    num_loader_workers=0,
    num_eval_loader_workers=0,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    use_espeak_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1,
    print_step=1,
    print_eval=True,
    test_sentences=[
        "Be a voice, not an echo.",
    ],
)
config.audio.do_trim_silence = True
config.audio.trim_db = 60
config.save_json(config_path)

# train the model for one epoch
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
    "--coqpit.datasets.0.name ljspeech "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
    "--coqpit.test_delay_epochs 0"
)
run_cli(command_train)

# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)