From c312acac7dd2ec4b51a2709581791ed5a656af0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 9 Aug 2021 08:00:43 +0000 Subject: [PATCH] =?UTF-8?q?Implement=20VITS=20model=20=20=F0=9F=9A=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit VITS model implementation built on Glow TTS and HiFiGAN layers. --- TTS/tts/configs/vits_config.py | 60 ++ TTS/tts/datasets/TTSDataset.py | 5 + TTS/tts/layers/generic/normalization.py | 25 + TTS/tts/layers/glow_tts/duration_predictor.py | 9 +- TTS/tts/layers/glow_tts/glow.py | 2 +- TTS/tts/layers/glow_tts/transformer.py | 74 +- TTS/tts/layers/losses.py | 141 ++++ TTS/tts/layers/vits/discriminator.py | 77 ++ TTS/tts/layers/vits/networks.py | 271 +++++++ .../vits/stochastic_duration_predictor.py | 276 +++++++ TTS/tts/layers/vits/transforms.py | 203 +++++ TTS/tts/models/base_tts.py | 23 +- TTS/tts/models/vits.py | 758 ++++++++++++++++++ TTS/tts/utils/text/__init__.py | 39 +- TTS/utils/audio.py | 6 +- TTS/vocoder/models/hifigan_discriminator.py | 22 +- TTS/vocoder/models/hifigan_generator.py | 21 +- TTS/vocoder/models/univnet_generator.py | 39 +- TTS/vocoder/utils/generic_utils.py | 23 +- tests/tts_tests/test_vits_train.py | 54 ++ 20 files changed, 2055 insertions(+), 73 deletions(-) create mode 100644 TTS/tts/configs/vits_config.py create mode 100644 TTS/tts/layers/vits/discriminator.py create mode 100644 TTS/tts/layers/vits/networks.py create mode 100644 TTS/tts/layers/vits/stochastic_duration_predictor.py create mode 100644 TTS/tts/layers/vits/transforms.py create mode 100644 TTS/tts/models/vits.py create mode 100644 tests/tts_tests/test_vits_train.py diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py new file mode 100644 index 00000000..e64944fe --- /dev/null +++ b/TTS/tts/configs/vits_config.py @@ -0,0 +1,60 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.vits import VitsArgs + + +@dataclass +class VitsConfig(BaseTTSConfig): + """Defines parameters for VITS End2End TTS model. + + Example: + + >>> from TTS.tts.configs import VitsConfig + >>> config = VitsConfig() + """ + + model: str = "vits" + # model specific params + model_args: VitsArgs = field(default_factory=VitsArgs) + + # optimizer + grad_clip: float = field(default_factory=lambda: [5, 5]) + lr_gen: float = 0.0002 + lr_disc: float = 0.0002 + lr_scheduler_gen: str = "ExponentialLR" + lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1}) + lr_scheduler_disc: str = "ExponentialLR" + lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1}) + scheduler_after_epoch: bool = True + optimizer: str = "AdamW" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01}) + + # loss params + kl_loss_alpha: float = 1.0 + disc_loss_alpha: float = 1.0 + gen_loss_alpha: float = 1.0 + feat_loss_alpha: float = 1.0 + mel_loss_alpha: float = 45.0 + + # data loader params + return_wav: bool = True + compute_linear_spec: bool = True + + # overrides + min_seq_len: int = 13 + max_seq_len: int = 200 + r: int = 1 # DO NOT CHANGE + add_blank: bool = True + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. 
I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py index aaa0ba50..89326c9c 100644 --- a/TTS/tts/datasets/TTSDataset.py +++ b/TTS/tts/datasets/TTSDataset.py @@ -191,6 +191,7 @@ class TTSDataset(Dataset): else: text, wav_file, speaker_name = item attn = None + raw_text = text wav = np.asarray(self.load_wav(wav_file), dtype=np.float32) @@ -236,6 +237,7 @@ class TTSDataset(Dataset): return self.load_data(self.rescue_item_idx) sample = { + "raw_text": raw_text, "text": text, "wav": wav, "attn": attn, @@ -360,6 +362,7 @@ class TTSDataset(Dataset): wav = [batch[idx]["wav"] for idx in ids_sorted_decreasing] item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing] text = [batch[idx]["text"] for idx in ids_sorted_decreasing] + raw_text = [batch[idx]["raw_text"] for idx in ids_sorted_decreasing] speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing] # get pre-computed d-vectors @@ -450,6 +453,7 @@ class TTSDataset(Dataset): attns = torch.FloatTensor(attns).unsqueeze(1) else: attns = None + # TODO: return dictionary return ( text, text_lenghts, @@ -463,6 +467,7 @@ class TTSDataset(Dataset): speaker_ids, attns, wav_padded, + raw_text, ) raise TypeError( diff --git a/TTS/tts/layers/generic/normalization.py b/TTS/tts/layers/generic/normalization.py index fd607b75..4766c77d 100644 --- a/TTS/tts/layers/generic/normalization.py +++ b/TTS/tts/layers/generic/normalization.py @@ -28,6 +28,31 @@ class LayerNorm(nn.Module): return x +class LayerNorm2(nn.Module): + """Layer norm for the 2nd dimension of the input using torch primitive. + Args: + channels (int): number of channels (2nd dimension) of the input. + eps (float): to prevent 0 division + + Shapes: + - input: (B, C, T) + - output: (B, C, T) + """ + + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = torch.nn.functional.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + class TemporalBatchNorm1d(nn.BatchNorm1d): """Normalize each channel separately over time and batch.""" diff --git a/TTS/tts/layers/glow_tts/duration_predictor.py b/TTS/tts/layers/glow_tts/duration_predictor.py index e35aeb68..2c0303be 100644 --- a/TTS/tts/layers/glow_tts/duration_predictor.py +++ b/TTS/tts/layers/glow_tts/duration_predictor.py @@ -18,7 +18,7 @@ class DurationPredictor(nn.Module): dropout_p (float): Dropout rate used after each conv layer. 
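+
+    Example:
+        Illustrative shapes only; the sizes mirror how the VITS model builds this layer and
+        `cond_channels=256` assumes speaker conditioning with the default embedding size.
+
+        >>> import torch
+        >>> layer = DurationPredictor(192, 256, 3, 0.5, cond_channels=256)
+        >>> x = torch.rand(2, 192, 27)
+        >>> x_mask = torch.ones(2, 1, 27)
+        >>> g = torch.rand(2, 256, 1)
+        >>> durations = layer(x, x_mask, g=g)  # [2, 1, 27]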
""" - def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p): + def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None): super().__init__() # class arguments self.in_channels = in_channels @@ -33,13 +33,18 @@ class DurationPredictor(nn.Module): self.norm_2 = LayerNorm(hidden_channels) # output layer self.proj = nn.Conv1d(hidden_channels, 1, 1) + if cond_channels is not None and cond_channels != 0: + self.cond = nn.Conv1d(cond_channels, in_channels, 1) - def forward(self, x, x_mask): + def forward(self, x, x_mask, g=None): """ Shapes: - x: :math:`[B, C, T]` - x_mask: :math:`[B, 1, T]` + - g: :math:`[B, C, 1]` """ + if g is not None: + x = x + self.cond(g) x = self.conv_1(x * x_mask) x = torch.relu(x) x = self.norm_1(x) diff --git a/TTS/tts/layers/glow_tts/glow.py b/TTS/tts/layers/glow_tts/glow.py index 33036537..392447de 100644 --- a/TTS/tts/layers/glow_tts/glow.py +++ b/TTS/tts/layers/glow_tts/glow.py @@ -16,7 +16,7 @@ class ResidualConv1dLayerNormBlock(nn.Module): :: x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o - |---------------> conv1d_1x1 -----------------------| + |---------------> conv1d_1x1 ------------------| Args: in_channels (int): number of input tensor channels. diff --git a/TTS/tts/layers/glow_tts/transformer.py b/TTS/tts/layers/glow_tts/transformer.py index 92cace78..ba6aa1e2 100644 --- a/TTS/tts/layers/glow_tts/transformer.py +++ b/TTS/tts/layers/glow_tts/transformer.py @@ -4,7 +4,7 @@ import torch from torch import nn from torch.nn import functional as F -from TTS.tts.layers.glow_tts.glow import LayerNorm +from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2 class RelativePositionMultiHeadAttention(nn.Module): @@ -271,7 +271,7 @@ class FeedForwardNetwork(nn.Module): dropout_p (float, optional): dropout rate. Defaults to 0. 
""" - def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dropout_p=0.0): + def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dropout_p=0.0, causal=False): super().__init__() self.in_channels = in_channels @@ -280,17 +280,46 @@ class FeedForwardNetwork(nn.Module): self.kernel_size = kernel_size self.dropout_p = dropout_p - self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2) - self.conv_2 = nn.Conv1d(hidden_channels, out_channels, kernel_size, padding=kernel_size // 2) + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size) + self.conv_2 = nn.Conv1d(hidden_channels, out_channels, kernel_size) self.dropout = nn.Dropout(dropout_p) def forward(self, x, x_mask): - x = self.conv_1(x * x_mask) + x = self.conv_1(self.padding(x * x_mask)) x = torch.relu(x) x = self.dropout(x) - x = self.conv_2(x * x_mask) + x = self.conv_2(self.padding(x * x_mask)) return x * x_mask + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, self._pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, self._pad_shape(padding)) + return x + + @staticmethod + def _pad_shape(padding): + l = padding[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + class RelativePositionTransformer(nn.Module): """Transformer with Relative Potional Encoding. @@ -310,20 +339,23 @@ class RelativePositionTransformer(nn.Module): If default, relative encoding is disabled and it is a regular transformer. Defaults to None. input_length (int, optional): input lenght to limit position encoding. Defaults to None. + layer_norm_type (str, optional): type "1" uses torch tensor operations and type "2" uses torch layer_norm + primitive. Use type "2", type "1: is for backward compat. Defaults to "1". """ def __init__( self, - in_channels, - out_channels, - hidden_channels, - hidden_channels_ffn, - num_heads, - num_layers, + in_channels: int, + out_channels: int, + hidden_channels: int, + hidden_channels_ffn: int, + num_heads: int, + num_layers: int, kernel_size=1, dropout_p=0.0, - rel_attn_window_size=None, - input_length=None, + rel_attn_window_size: int = None, + input_length: int = None, + layer_norm_type: str = "1", ): super().__init__() self.hidden_channels = hidden_channels @@ -351,7 +383,12 @@ class RelativePositionTransformer(nn.Module): input_length=input_length, ) ) - self.norm_layers_1.append(LayerNorm(hidden_channels)) + if layer_norm_type == "1": + self.norm_layers_1.append(LayerNorm(hidden_channels)) + elif layer_norm_type == "2": + self.norm_layers_1.append(LayerNorm2(hidden_channels)) + else: + raise ValueError(" [!] 
Unknown layer norm type") if hidden_channels != out_channels and (idx + 1) == self.num_layers: self.proj = nn.Conv1d(hidden_channels, out_channels, 1) @@ -366,7 +403,12 @@ class RelativePositionTransformer(nn.Module): ) ) - self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels)) + if layer_norm_type == "1": + self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels)) + elif layer_norm_type == "2": + self.norm_layers_2.append(LayerNorm2(hidden_channels if (idx + 1) != self.num_layers else out_channels)) + else: + raise ValueError(" [!] Unknown layer norm type") def forward(self, x, x_mask): """ diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 07b58974..171b0217 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -2,11 +2,13 @@ import math import numpy as np import torch +from coqpit import Coqpit from torch import nn from torch.nn import functional from TTS.tts.utils.data import sequence_mask from TTS.tts.utils.ssim import ssim +from TTS.utils.audio import TorchSTFT # pylint: disable=abstract-method @@ -514,3 +516,142 @@ class AlignTTSLoss(nn.Module): + self.mdn_alpha * mdn_loss ) return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss} + + +class VitsGeneratorLoss(nn.Module): + def __init__(self, c: Coqpit): + super().__init__() + self.kl_loss_alpha = c.kl_loss_alpha + self.gen_loss_alpha = c.gen_loss_alpha + self.feat_loss_alpha = c.feat_loss_alpha + self.mel_loss_alpha = c.mel_loss_alpha + self.stft = TorchSTFT( + c.audio.fft_size, + c.audio.hop_length, + c.audio.win_length, + sample_rate=c.audio.sample_rate, + mel_fmin=c.audio.mel_fmin, + mel_fmax=c.audio.mel_fmax, + n_mels=c.audio.num_mels, + use_mel=True, + do_amp_to_db=True, + ) + + @staticmethod + def feature_loss(feats_real, feats_generated): + loss = 0 + for dr, dg in zip(feats_real, feats_generated): + for rl, gl in zip(dr, dg): + rl = rl.float().detach() + gl = gl.float() + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + @staticmethod + def generator_loss(scores_fake): + loss = 0 + gen_losses = [] + for dg in scores_fake: + dg = dg.float() + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + @staticmethod + def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): + """ + z_p, logs_q: [b, h, t_t] + m_p, logs_p: [b, h, t_t] + """ + z_p = z_p.float() + logs_q = logs_q.float() + m_p = m_p.float() + logs_p = logs_p.float() + z_mask = z_mask.float() + + kl = logs_p - logs_q - 0.5 + kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) + kl = torch.sum(kl * z_mask) + l = kl / torch.sum(z_mask) + return l + + def forward( + self, + waveform, + waveform_hat, + z_p, + logs_q, + m_p, + logs_p, + z_len, + scores_disc_fake, + feats_disc_fake, + feats_disc_real, + ): + """ + Shapes: + - wavefrom: :math:`[B, 1, T]` + - waveform_hat: :math:`[B, 1, T]` + - z_p: :math:`[B, C, T]` + - logs_q: :math:`[B, C, T]` + - m_p: :math:`[B, C, T]` + - logs_p: :math:`[B, C, T]` + - z_len: :math:`[B]` + - scores_disc_fake[i]: :math:`[B, C]` + - feats_disc_fake[i][j]: :math:`[B, C, T', P]` + - feats_disc_real[i][j]: :math:`[B, C, T', P]` + """ + loss = 0.0 + return_dict = {} + z_mask = sequence_mask(z_len).float() + # compute mel spectrograms from the waveforms + mel = self.stft(waveform) + mel_hat = self.stft(waveform_hat) + # compute losses + loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * 
self.feat_loss_alpha + loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha + loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha + loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha + loss = loss_kl + loss_feat + loss_mel + loss_gen + # pass losses to the dict + return_dict["loss_gen"] = loss_gen + return_dict["loss_kl"] = loss_kl + return_dict["loss_feat"] = loss_feat + return_dict["loss_mel"] = loss_mel + return_dict["loss"] = loss + return return_dict + + +class VitsDiscriminatorLoss(nn.Module): + def __init__(self, c: Coqpit): + super().__init__() + self.disc_loss_alpha = c.disc_loss_alpha + + @staticmethod + def discriminator_loss(scores_real, scores_fake): + loss = 0 + real_losses = [] + fake_losses = [] + for dr, dg in zip(scores_real, scores_fake): + dr = dr.float() + dg = dg.float() + real_loss = torch.mean((1 - dr) ** 2) + fake_loss = torch.mean(dg ** 2) + loss += real_loss + fake_loss + real_losses.append(real_loss.item()) + fake_losses.append(fake_loss.item()) + + return loss, real_losses, fake_losses + + def forward(self, scores_disc_real, scores_disc_fake): + loss = 0.0 + return_dict = {} + loss_disc, _, _ = self.discriminator_loss(scores_disc_real, scores_disc_fake) + return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha + loss = loss + loss_disc + return_dict["loss_disc"] = loss_disc + return_dict["loss"] = loss + return return_dict diff --git a/TTS/tts/layers/vits/discriminator.py b/TTS/tts/layers/vits/discriminator.py new file mode 100644 index 00000000..650c9b61 --- /dev/null +++ b/TTS/tts/layers/vits/discriminator.py @@ -0,0 +1,77 @@ +import torch +from torch import nn +from torch.nn.modules.conv import Conv1d + +from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator + + +class DiscriminatorS(torch.nn.Module): + """HiFiGAN Scale Discriminator. Channel sizes are different from the original HiFiGAN. + + Args: + use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm. + """ + + def __init__(self, use_spectral_norm=False): + super().__init__() + norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + Tensor: discriminator scores. + List[Tensor]: list of features from the convolutiona layers. + """ + feat = [] + for l in self.convs: + x = l(x) + x = torch.nn.functional.leaky_relu(x, 0.1) + feat.append(x) + x = self.conv_post(x) + feat.append(x) + x = torch.flatten(x, 1, -1) + return x, feat + + +class VitsDiscriminator(nn.Module): + """VITS discriminator wrapping one Scale Discriminator and a stack of Period Discriminator. + + :: + waveform -> ScaleDiscriminator() -> scores_sd, feats_sd --> append() -> scores, feats + |--> MultiPeriodDiscriminator() -> scores_mpd, feats_mpd ^ + + Args: + use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm. 
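+
+    Example:
+        A minimal sketch with an arbitrary input length.
+
+        >>> import torch
+        >>> disc = DiscriminatorS()
+        >>> wav = torch.rand(1, 1, 8192)
+        >>> scores, feats = disc(wav)
+        >>> len(feats)
+        7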
+ """ + + def __init__(self, use_spectral_norm=False): + super().__init__() + self.sd = DiscriminatorS(use_spectral_norm=use_spectral_norm) + self.mpd = MultiPeriodDiscriminator(use_spectral_norm=use_spectral_norm) + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + List[Tensor]: discriminator scores. + List[List[Tensor]]: list of list of features from each layers of each discriminator. + """ + scores, feats = self.mpd(x) + score_sd, feats_sd = self.sd(x) + return scores + [score_sd], feats + [feats_sd] diff --git a/TTS/tts/layers/vits/networks.py b/TTS/tts/layers/vits/networks.py new file mode 100644 index 00000000..cf9d6e41 --- /dev/null +++ b/TTS/tts/layers/vits/networks.py @@ -0,0 +1,271 @@ +import math + +import torch +from torch import nn + +from TTS.tts.layers.glow_tts.glow import WN +from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer +from TTS.tts.utils.data import sequence_mask + +LRELU_SLOPE = 0.1 + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class TextEncoder(nn.Module): + def __init__( + self, + n_vocab: int, + out_channels: int, + hidden_channels: int, + hidden_channels_ffn: int, + num_heads: int, + num_layers: int, + kernel_size: int, + dropout_p: float, + ): + """Text Encoder for VITS model. + + Args: + n_vocab (int): Number of characters for the embedding layer. + out_channels (int): Number of channels for the output. + hidden_channels (int): Number of channels for the hidden layers. + hidden_channels_ffn (int): Number of channels for the convolutional layers. + num_heads (int): Number of attention heads for the Transformer layers. + num_layers (int): Number of Transformer layers. + kernel_size (int): Kernel size for the FFN layers in Transformer network. + dropout_p (float): Dropout rate for the Transformer layers. 
+ """ + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + + self.emb = nn.Embedding(n_vocab, hidden_channels) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5) + + self.encoder = RelativePositionTransformer( + in_channels=hidden_channels, + out_channels=hidden_channels, + hidden_channels=hidden_channels, + hidden_channels_ffn=hidden_channels_ffn, + num_heads=num_heads, + num_layers=num_layers, + kernel_size=kernel_size, + dropout_p=dropout_p, + layer_norm_type="2", + rel_attn_window_size=4, + ) + + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths): + """ + Shapes: + - x: :math:`[B, T]` + - x_length: :math:`[B]` + """ + x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + + x = self.encoder(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + + m, logs = torch.split(stats, self.out_channels, dim=1) + return x, m, logs, x_mask + + +class ResidualCouplingBlock(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + dropout_p=0, + cond_channels=0, + mean_only=False, + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.half_channels = channels // 2 + self.mean_only = mean_only + # input layer + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + # coupling layers + self.enc = WN( + hidden_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + dropout_p=dropout_p, + c_in_channels=cond_channels, + ) + # output layer + # Initializing last layer to 0 makes the affine coupling layers + # do nothing at first. This helps with training stability + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + - g: :math:`[B, C, 1]` + """ + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, log_scale = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + log_scale = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(log_scale) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(log_scale, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-log_scale) * x_mask + x = torch.cat([x0, x1], 1) + return x + + +class ResidualCouplingBlocks(nn.Module): + def __init__( + self, + channels: int, + hidden_channels: int, + kernel_size: int, + dilation_rate: int, + num_layers: int, + num_flows=4, + cond_channels=0, + ): + """Redisual Coupling blocks for VITS flow layers. + + Args: + channels (int): Number of input and output tensor channels. + hidden_channels (int): Number of hidden network channels. + kernel_size (int): Kernel size of the WaveNet layers. + dilation_rate (int): Dilation rate of the WaveNet layers. + num_layers (int): Number of the WaveNet layers. + num_flows (int, optional): Number of Residual Coupling blocks. Defaults to 4. + cond_channels (int, optional): Number of channels of the conditioning tensor. Defaults to 0. 
+ """ + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.num_layers = num_layers + self.num_flows = num_flows + self.cond_channels = cond_channels + + self.flows = nn.ModuleList() + for _ in range(num_flows): + self.flows.append( + ResidualCouplingBlock( + channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + cond_channels=cond_channels, + mean_only=True, + ) + ) + + def forward(self, x, x_mask, g=None, reverse=False): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + - g: :math:`[B, C, 1]` + """ + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + x = torch.flip(x, [1]) + else: + for flow in reversed(self.flows): + x = torch.flip(x, [1]) + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class PosteriorEncoder(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int, + kernel_size: int, + dilation_rate: int, + num_layers: int, + cond_channels=0, + ): + """Posterior Encoder of VITS model. + + :: + x -> conv1x1() -> WaveNet() (non-causal) -> conv1x1() -> split() -> [m, s] -> sample(m, s) -> z + + Args: + in_channels (int): Number of input tensor channels. + out_channels (int): Number of output tensor channels. + hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size of the WaveNet convolution layers. + dilation_rate (int): Dilation rate of the WaveNet layers. + num_layers (int): Number of the WaveNet layers. + cond_channels (int, optional): Number of conditioning tensor channels. Defaults to 0. + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.num_layers = num_layers + self.cond_channels = cond_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = WN( + hidden_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels=cond_channels + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_lengths: :math:`[B, 1]` + - g: :math:`[B, C, 1]` + """ + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + mean, log_scale = torch.split(stats, self.out_channels, dim=1) + z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask + return z, mean, log_scale, x_mask diff --git a/TTS/tts/layers/vits/stochastic_duration_predictor.py b/TTS/tts/layers/vits/stochastic_duration_predictor.py new file mode 100644 index 00000000..ae1edebb --- /dev/null +++ b/TTS/tts/layers/vits/stochastic_duration_predictor.py @@ -0,0 +1,276 @@ +import math + +import torch +from torch import nn +from torch.nn import functional as F + +from TTS.tts.layers.generic.normalization import LayerNorm2 +from TTS.tts.layers.vits.transforms import piecewise_rational_quadratic_transform + + +class DilatedDepthSeparableConv(nn.Module): + def __init__(self, channels, kernel_size, num_layers, dropout_p=0.0) -> torch.tensor: + """Dilated Depth-wise Separable Convolution module. 
+ + :: + x |-> DDSConv(x) -> LayerNorm(x) -> GeLU(x) -> Conv1x1(x) -> LayerNorm(x) -> GeLU(x) -> + -> o + |-------------------------------------------------------------------------------------^ + + Args: + channels ([type]): [description] + kernel_size ([type]): [description] + num_layers ([type]): [description] + dropout_p (float, optional): [description]. Defaults to 0.0. + + Returns: + torch.tensor: Network output masked by the input sequence mask. + """ + super().__init__() + self.num_layers = num_layers + + self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(num_layers): + dilation = kernel_size ** i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append( + nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding) + ) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm2(channels)) + self.norms_2.append(LayerNorm2(channels)) + self.dropout = nn.Dropout(dropout_p) + + def forward(self, x, x_mask, g=None): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + """ + if g is not None: + x = x + g + for i in range(self.num_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.dropout(y) + x = x + y + return x * x_mask + + +class ElementwiseAffine(nn.Module): + """Element-wise affine transform like no-population stats BatchNorm alternative. + + Args: + channels (int): Number of input tensor channels. + """ + + def __init__(self, channels): + super().__init__() + self.translation = nn.Parameter(torch.zeros(channels, 1)) + self.log_scale = nn.Parameter(torch.zeros(channels, 1)) + + def forward(self, x, x_mask, reverse=False, **kwargs): # pylint: disable=unused-argument + if not reverse: + y = (x * torch.exp(self.log_scale) + self.translation) * x_mask + logdet = torch.sum(self.log_scale * x_mask, [1, 2]) + return y, logdet + x = (x - self.translation) * torch.exp(-self.log_scale) * x_mask + return x + + +class ConvFlow(nn.Module): + """Dilated depth separable convolutional based spline flow. + + Args: + in_channels (int): Number of input tensor channels. + hidden_channels (int): Number of in network channels. + kernel_size (int): Convolutional kernel size. + num_layers (int): Number of convolutional layers. + num_bins (int, optional): Number of spline bins. Defaults to 10. + tail_bound (float, optional): Tail bound for PRQT. Defaults to 5.0. + """ + + def __init__( + self, + in_channels: int, + hidden_channels: int, + kernel_size: int, + num_layers: int, + num_bins=10, + tail_bound=5.0, + ): + super().__init__() + self.num_bins = num_bins + self.tail_bound = tail_bound + self.hidden_channels = hidden_channels + self.half_channels = in_channels // 2 + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers, dropout_p=0.0) + self.proj = nn.Conv1d(hidden_channels, self.half_channels * (num_bins * 3 - 1), 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) + h = self.convs(h, x_mask, g=g) + h = self.proj(h) * x_mask + + b, c, t = x0.shape + h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] 
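+        # `h` packs the unnormalized spline parameters for each half-channel and time step:
+        # `num_bins` widths, `num_bins` heights and `num_bins - 1` interior knot derivatives.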
+ + unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.hidden_channels) + unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.hidden_channels) + unnormalized_derivatives = h[..., 2 * self.num_bins :] + + x1, logabsdet = piecewise_rational_quadratic_transform( + x1, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=reverse, + tails="linear", + tail_bound=self.tail_bound, + ) + + x = torch.cat([x0, x1], 1) * x_mask + logdet = torch.sum(logabsdet * x_mask, [1, 2]) + if not reverse: + return x, logdet + return x + + +class StochasticDurationPredictor(nn.Module): + """Stochastic duration predictor with Spline Flows. + + It applies Variational Dequantization and Variationsl Data Augmentation. + + Paper: + SDP: https://arxiv.org/pdf/2106.06103.pdf + Spline Flow: https://arxiv.org/abs/1906.04032 + + :: + ## Inference + + x -> TextCondEncoder() -> Flow() -> dr_hat + noise ----------------------^ + + ## Training + |---------------------| + x -> TextCondEncoder() -> + -> PosteriorEncoder() -> split() -> z_u, z_v -> (d - z_u) -> concat() -> Flow() -> noise + d -> DurCondEncoder() -> ^ | + |------------------------------------------------------------------------------| + + Args: + in_channels (int): Number of input tensor channels. + hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size of convolutional layers. + dropout_p (float): Dropout rate. + num_flows (int, optional): Number of flow blocks. Defaults to 4. + cond_channels (int, optional): Number of channels of conditioning tensor. Defaults to 0. + """ + + def __init__( + self, in_channels: int, hidden_channels: int, kernel_size: int, dropout_p: float, num_flows=4, cond_channels=0 + ): + super().__init__() + + # condition encoder text + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p) + self.proj = nn.Conv1d(hidden_channels, hidden_channels, 1) + + # posterior encoder + self.flows = nn.ModuleList() + self.flows.append(ElementwiseAffine(2)) + self.flows += [ConvFlow(2, hidden_channels, kernel_size, num_layers=3) for _ in range(num_flows)] + + # condition encoder duration + self.post_pre = nn.Conv1d(1, hidden_channels, 1) + self.post_convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p) + self.post_proj = nn.Conv1d(hidden_channels, hidden_channels, 1) + + # flow layers + self.post_flows = nn.ModuleList() + self.post_flows.append(ElementwiseAffine(2)) + self.post_flows += [ConvFlow(2, hidden_channels, kernel_size, num_layers=3) for _ in range(num_flows)] + + if cond_channels != 0 and cond_channels is not None: + self.cond = nn.Conv1d(cond_channels, hidden_channels, 1) + + def forward(self, x, x_mask, dr=None, g=None, reverse=False, noise_scale=1.0): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + - dr: :math:`[B, 1, T]` + - g: :math:`[B, C]` + """ + # condition encoder text + x = self.pre(x) + if g is not None: + x = x + self.cond(g) + x = self.convs(x, x_mask) + x = self.proj(x) * x_mask + + if not reverse: + flows = self.flows + assert dr is not None + + # condition encoder duration + h = self.post_pre(dr) + h = self.post_convs(h, x_mask) + h = self.post_proj(h) * x_mask + noise = torch.rand(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask + z_q = noise + + # posterior encoder + logdet_tot_q = 0.0 + for idx, flow in enumerate(self.post_flows): + 
z_q, logdet_q = flow(z_q, x_mask, g=(x + h)) + logdet_tot_q = logdet_tot_q + logdet_q + if idx > 0: + z_q = torch.flip(z_q, [1]) + + z_u, z_v = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (dr - u) * x_mask + + # posterior encoder - neg log likelihood + logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]) + nll_posterior_encoder = ( + torch.sum(-0.5 * (math.log(2 * math.pi) + (noise ** 2)) * x_mask, [1, 2]) - logdet_tot_q + ) + + z0 = torch.log(torch.clamp_min(z0, 1e-5)) * x_mask + logdet_tot = torch.sum(-z0, [1, 2]) + z = torch.cat([z0, z_v], 1) + + # flow layers + for idx, flow in enumerate(flows): + z, logdet = flow(z, x_mask, g=x, reverse=reverse) + logdet_tot = logdet_tot + logdet + if idx > 0: + z = torch.flip(z, [1]) + + # flow layers - neg log likelihood + nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot + return nll_flow_layers + nll_posterior_encoder + + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = torch.rand(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale + for flow in flows: + z = torch.flip(z, [1]) + z = flow(z, x_mask, g=x, reverse=reverse) + + z0, _ = torch.split(z, [1, 1], 1) + logw = z0 + return logw diff --git a/TTS/tts/layers/vits/transforms.py b/TTS/tts/layers/vits/transforms.py new file mode 100644 index 00000000..c1505554 --- /dev/null +++ b/TTS/tts/layers/vits/transforms.py @@ -0,0 +1,203 @@ +# adopted from https://github.com/bayesiains/nflows + +import numpy as np +import torch +from torch.nn import functional as F + +DEFAULT_MIN_BIN_WIDTH = 1e-3 +DEFAULT_MIN_BIN_HEIGHT = 1e-3 +DEFAULT_MIN_DERIVATIVE = 1e-3 + + +def piecewise_rational_quadratic_transform( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails=None, + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + + if tails is None: + spline_fn = rational_quadratic_spline + spline_kwargs = {} + else: + spline_fn = unconstrained_rational_quadratic_spline + spline_kwargs = {"tails": tails, "tail_bound": tail_bound} + + outputs, logabsdet = spline_fn( + inputs=inputs, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + **spline_kwargs, + ) + return outputs, logabsdet + + +def searchsorted(bin_locations, inputs, eps=1e-6): + bin_locations[..., -1] += eps + return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 + + +def unconstrained_rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails="linear", + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = torch.zeros_like(inputs) + logabsdet = torch.zeros_like(inputs) + + if tails == "linear": + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) + constant = np.log(np.exp(1 - min_derivative) - 1) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + 
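+        # `constant` is chosen so that `min_derivative + softplus(constant) == 1`, i.e. the boundary
+        # derivatives are 1 and the transform continues as the identity outside the interval.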
outputs[outside_interval_mask] = inputs[outside_interval_mask] + logabsdet[outside_interval_mask] = 0 + else: + raise RuntimeError("{} tails are not implemented.".format(tails)) + + outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + unnormalized_widths=unnormalized_widths[inside_interval_mask, :], + unnormalized_heights=unnormalized_heights[inside_interval_mask, :], + unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], + inverse=inverse, + left=-tail_bound, + right=tail_bound, + bottom=-tail_bound, + top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + ) + + return outputs, logabsdet + + +def rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + left=0.0, + right=1.0, + bottom=0.0, + top=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if torch.min(inputs) < left or torch.max(inputs) > right: + raise ValueError("Input to a transform is not within its domain") + + num_bins = unnormalized_widths.shape[-1] + + if min_bin_width * num_bins > 1.0: + raise ValueError("Minimal bin width too large for the number of bins") + if min_bin_height * num_bins > 1.0: + raise ValueError("Minimal bin height too large for the number of bins") + + widths = F.softmax(unnormalized_widths, dim=-1) + widths = min_bin_width + (1 - min_bin_width * num_bins) * widths + cumwidths = torch.cumsum(widths, dim=-1) + cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) + cumwidths = (right - left) * cumwidths + left + cumwidths[..., 0] = left + cumwidths[..., -1] = right + widths = cumwidths[..., 1:] - cumwidths[..., :-1] + + derivatives = min_derivative + F.softplus(unnormalized_derivatives) + + heights = F.softmax(unnormalized_heights, dim=-1) + heights = min_bin_height + (1 - min_bin_height * num_bins) * heights + cumheights = torch.cumsum(heights, dim=-1) + cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) + cumheights = (top - bottom) * cumheights + bottom + cumheights[..., 0] = bottom + cumheights[..., -1] = top + heights = cumheights[..., 1:] - cumheights[..., :-1] + + if inverse: + bin_idx = searchsorted(cumheights, inputs)[..., None] + else: + bin_idx = searchsorted(cumwidths, inputs)[..., None] + + input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] + input_bin_widths = widths.gather(-1, bin_idx)[..., 0] + + input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] + delta = heights / widths + input_delta = delta.gather(-1, bin_idx)[..., 0] + + input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] + input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] + + input_heights = heights.gather(-1, bin_idx)[..., 0] + + if inverse: + a = (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + input_heights * (input_delta - input_derivatives) + b = input_heights * input_derivatives - (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + c = -input_delta * (inputs - input_cumheights) + + discriminant = b.pow(2) - 4 * a * c + assert (discriminant >= 0).all() + + root = (2 * c) / (-b - torch.sqrt(discriminant)) + outputs = root * input_bin_widths + input_cumwidths + + theta_one_minus_theta = root * (1 - root) + denominator = input_delta 
+ ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta + ) + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, -logabsdet + else: + theta = (inputs - input_cumwidths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + + numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta + ) + outputs = input_cumheights + numerator / denominator + + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, logabsdet diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index e441cc05..cd4c33d0 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -212,13 +212,22 @@ class BaseTTS(BaseModel): else None, ) - if ( - config.use_phonemes - and config.compute_input_seq_cache - and not os.path.exists(dataset.phoneme_cache_path) - ): - # precompute phonemes to have a better estimate of sequence lengths. - dataset.compute_input_seq(config.num_loader_workers) + if config.use_phonemes and config.compute_input_seq_cache: + if hasattr(self, "eval_data_items") and is_eval: + dataset.items = self.eval_data_items + elif hasattr(self, "train_data_items") and not is_eval: + dataset.items = self.train_data_items + else: + # precompute phonemes to have a better estimate of sequence lengths. 
+ dataset.compute_input_seq(config.num_loader_workers) + + # TODO: find a more efficient solution + # cheap hack - store items in the model state to avoid recomputing when reinit the dataset + if is_eval: + self.eval_data_items = dataset.items + else: + self.train_data_items = dataset.items + dataset.sort_items() sampler = DistributedSampler(dataset) if num_gpus > 1 else None diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py new file mode 100644 index 00000000..9a2eec89 --- /dev/null +++ b/TTS/tts/models/vits.py @@ -0,0 +1,758 @@ +from dataclasses import dataclass, field +from typing import Dict, List, Tuple + +import torch +from coqpit import Coqpit +from torch import nn +from torch.cuda.amp.autocast_mode import autocast + +from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor +from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path +from TTS.tts.layers.vits.discriminator import VitsDiscriminator +from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder +from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor + +# from TTS.tts.layers.vits.sdp import StochasticDurationPredictor +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.data import sequence_mask +from TTS.tts.utils.speakers import get_speaker_manager +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.visual import plot_alignment +from TTS.utils.audio import AudioProcessor +from TTS.utils.trainer_utils import get_optimizer, get_scheduler +from TTS.vocoder.models.hifigan_generator import HifiganGenerator +from TTS.vocoder.utils.generic_utils import plot_results + + +def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4): + """Segment each sample in a batch based on the provided segment indices""" + segments = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + index_start = segment_indices[i] + index_end = index_start + segment_size + segments[i] = x[i, :, index_start:index_end] + return segments + + +def rand_segment(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4): + """Create random segments based on the input lengths.""" + B, _, T = x.size() + if x_lengths is None: + x_lengths = T + max_idxs = x_lengths - segment_size + 1 + assert all(max_idxs > 0), " [!] At least one sample is shorter than the segment size." + ids_str = (torch.rand([B]).type_as(x) * max_idxs).long() + ret = segment(x, ids_str, segment_size) + return ret, ids_str + + +@dataclass +class VitsArgs(Coqpit): + """VITS model arguments. + + Args: + + num_chars (int): + Number of characters in the vocabulary. Defaults to 100. + + out_channels (int): + Number of output channels. Defaults to 513. + + spec_segment_size (int): + Decoder input segment size. Defaults to 32 `(32 * hoplength = waveform length)`. + + hidden_channels (int): + Number of hidden channels of the model. Defaults to 192. + + hidden_channels_ffn_text_encoder (int): + Number of hidden channels of the feed-forward layers of the text encoder transformer. Defaults to 256. + + num_heads_text_encoder (int): + Number of attention heads of the text encoder transformer. Defaults to 2. + + num_layers_text_encoder (int): + Number of transformer layers in the text encoder. Defaults to 6. + + kernel_size_text_encoder (int): + Kernel size of the text encoder transformer FFN layers. Defaults to 3. + + dropout_p_text_encoder (float): + Dropout rate of the text encoder. Defaults to 0.1. 
+ + kernel_size_posterior_encoder (int): + Kernel size of the posterior encoder's WaveNet layers. Defaults to 5. + + dilatation_posterior_encoder (int): + Dilation rate of the posterior encoder's WaveNet layers. Defaults to 1. + + num_layers_posterior_encoder (int): + Number of posterior encoder's WaveNet layers. Defaults to 16. + + kernel_size_flow (int): + Kernel size of the Residual Coupling layers of the flow network. Defaults to 5. + + dilatation_flow (int): + Dilation rate of the Residual Coupling WaveNet layers of the flow network. Defaults to 1. + + num_layers_flow (int): + Number of Residual Coupling WaveNet layers of the flow network. Defaults to 6. + + resblock_type_decoder (str): + Type of the residual block in the decoder network. Defaults to "1". + + resblock_kernel_sizes_decoder (List[int]): + Kernel sizes of the residual blocks in the decoder network. Defaults to `[3, 7, 11]`. + + resblock_dilation_sizes_decoder (List[List[int]]): + Dilation sizes of the residual blocks in the decoder network. Defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`. + + upsample_rates_decoder (List[int]): + Upsampling rates for each concecutive upsampling layer in the decoder network. The multiply of these + values must be equal to the kop length used for computing spectrograms. Defaults to `[8, 8, 2, 2]`. + + upsample_initial_channel_decoder (int): + Number of hidden channels of the first upsampling convolution layer of the decoder network. Defaults to 512. + + upsample_kernel_sizes_decoder (List[int]): + Kernel sizes for each upsampling layer of the decoder network. Defaults to `[16, 16, 4, 4]`. + + use_sdp (int): + Use Stochastic Duration Predictor. Defaults to True. + + noise_scale (float): + Noise scale used for the sample noise tensor in training. Defaults to 1.0. + + inference_noise_scale (float): + Noise scale used for the sample noise tensor in inference. Defaults to 0.667. + + length_scale (int): + Scale factor for the predicted duration values. Smaller values result faster speech. Defaults to 1. + + noise_scale_dp (float): + Noise scale used by the Stochastic Duration Predictor sample noise in training. Defaults to 1.0. + + inference_noise_scale_dp (float): + Noise scale for the Stochastic Duration Predictor in inference. Defaults to 0.8. + + max_inference_len (int): + Maximum inference length to limit the memory use. Defaults to None. + + init_discriminator (bool): + Initialize the disciminator network if set True. Set False for inference. Defaults to True. + + use_spectral_norm_disriminator (bool): + Use spectral normalization over weight norm in the discriminator. Defaults to False. + + use_speaker_embedding (bool): + Enable/Disable speaker embedding for multi-speaker models. Defaults to False. + + num_speakers (int): + Number of speakers for the speaker embedding layer. Defaults to 0. + + speakers_file (str): + Path to the speaker mapping file for the Speaker Manager. Defaults to None. + + speaker_embedding_channels (int): + Number of speaker embedding channels. Defaults to 256. + + use_d_vector_file (bool): + Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False. + + d_vector_dim (int): + Number of d-vector channels. Defaults to 0. + + detach_dp_input (bool): + Detach duration predictor's input from the network for stopping the gradients. Defaults to True. 
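+
+    Example:
+        Fields can be overridden at construction like any Coqpit dataclass; the values below
+        are the defaults defined in this class.
+
+        >>> from TTS.tts.models.vits import VitsArgs
+        >>> args = VitsArgs(num_chars=100, use_sdp=True)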
+ """ + + num_chars: int = 100 + out_channels: int = 513 + spec_segment_size: int = 32 + hidden_channels: int = 192 + hidden_channels_ffn_text_encoder: int = 768 + num_heads_text_encoder: int = 2 + num_layers_text_encoder: int = 6 + kernel_size_text_encoder: int = 3 + dropout_p_text_encoder: int = 0.1 + kernel_size_posterior_encoder: int = 5 + dilation_rate_posterior_encoder: int = 1 + num_layers_posterior_encoder: int = 16 + kernel_size_flow: int = 5 + dilation_rate_flow: int = 1 + num_layers_flow: int = 4 + resblock_type_decoder: int = "1" + resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11]) + resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) + upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2]) + upsample_initial_channel_decoder: int = 512 + upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4]) + use_sdp: int = True + noise_scale: float = 1.0 + inference_noise_scale: float = 0.667 + length_scale: int = 1 + noise_scale_dp: float = 1.0 + inference_noise_scale_dp: float = 0.8 + max_inference_len: int = None + init_discriminator: bool = True + use_spectral_norm_disriminator: bool = False + use_speaker_embedding: bool = False + num_speakers: int = 0 + speakers_file: str = None + speaker_embedding_channels: int = 256 + use_d_vector_file: bool = False + d_vector_dim: int = 0 + detach_dp_input: bool = True + + +class Vits(BaseTTS): + """VITS TTS model + + Paper:: + https://arxiv.org/pdf/2106.06103.pdf + + Paper Abstract:: + Several recent end-to-end text-to-speech (TTS) models enabling single-stage training and parallel + sampling have been proposed, but their sample quality does not match that of two-stage TTS systems. + In this work, we present a parallel endto-end TTS method that generates more natural sounding audio than + current two-stage models. Our method adopts variational inference augmented with normalizing flows and + an adversarial training process, which improves the expressive power of generative modeling. We also propose a + stochastic duration predictor to synthesize speech with diverse rhythms from input text. With the + uncertainty modeling over latent variables and the stochastic duration predictor, our method expresses the + natural one-to-many relationship in which a text input can be spoken in multiple ways + with different pitches and rhythms. A subjective human evaluation (mean opinion score, or MOS) + on the LJ Speech, a single speaker dataset, shows that our method outperforms the best publicly + available TTS systems and achieves a MOS comparable to ground truth. + + Check :class:`TTS.tts.configs.vits_config.VitsConfig` for class arguments. 
+ + Examples: + >>> from TTS.tts.configs import VitsConfig + >>> from TTS.tts.models.vits import Vits + >>> config = VitsConfig() + >>> model = Vits(config) + """ + + # pylint: disable=dangerous-default-value + + def __init__(self, config: Coqpit): + + super().__init__() + + self.END2END = True + + if config.__class__.__name__ == "VitsConfig": + # loading from VitsConfig + if "num_chars" not in config: + _, self.config, num_chars = self.get_characters(config) + config.model_args.num_chars = num_chars + else: + self.config = config + config.model_args.num_chars = config.num_chars + args = self.config.model_args + elif isinstance(config, VitsArgs): + # loading from VitsArgs + self.config = config + args = config + else: + raise ValueError("config must be either a VitsConfig or VitsArgs") + + self.args = args + + self.init_multispeaker(config) + + self.length_scale = args.length_scale + self.noise_scale = args.noise_scale + self.inference_noise_scale = args.inference_noise_scale + self.inference_noise_scale_dp = args.inference_noise_scale_dp + self.noise_scale_dp = args.noise_scale_dp + self.max_inference_len = args.max_inference_len + self.spec_segment_size = args.spec_segment_size + + self.text_encoder = TextEncoder( + args.num_chars, + args.hidden_channels, + args.hidden_channels, + args.hidden_channels_ffn_text_encoder, + args.num_heads_text_encoder, + args.num_layers_text_encoder, + args.kernel_size_text_encoder, + args.dropout_p_text_encoder, + ) + + self.posterior_encoder = PosteriorEncoder( + args.out_channels, + args.hidden_channels, + args.hidden_channels, + kernel_size=args.kernel_size_posterior_encoder, + dilation_rate=args.dilation_rate_posterior_encoder, + num_layers=args.num_layers_posterior_encoder, + cond_channels=self.embedded_speaker_dim, + ) + + self.flow = ResidualCouplingBlocks( + args.hidden_channels, + args.hidden_channels, + kernel_size=args.kernel_size_flow, + dilation_rate=args.dilation_rate_flow, + num_layers=args.num_layers_flow, + cond_channels=self.embedded_speaker_dim, + ) + + if args.use_sdp: + self.duration_predictor = StochasticDurationPredictor( + args.hidden_channels, 192, 3, 0.5, 4, cond_channels=self.embedded_speaker_dim + ) + else: + self.duration_predictor = DurationPredictor( + args.hidden_channels, 256, 3, 0.5, cond_channels=self.embedded_speaker_dim + ) + + self.waveform_decoder = HifiganGenerator( + args.hidden_channels, + 1, + args.resblock_type_decoder, + args.resblock_dilation_sizes_decoder, + args.resblock_kernel_sizes_decoder, + args.upsample_kernel_sizes_decoder, + args.upsample_initial_channel_decoder, + args.upsample_rates_decoder, + inference_padding=0, + cond_channels=self.embedded_speaker_dim, + conv_pre_weight_norm=False, + conv_post_weight_norm=False, + conv_post_bias=False, + ) + + if args.init_discriminator: + self.disc = VitsDiscriminator(use_spectral_norm=args.use_spectral_norm_disriminator) + + def init_multispeaker(self, config: Coqpit, data: List = None): + """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer + or with external `d_vectors` computed from a speaker encoder model. + + If you need a different behaviour, override this function for your model. + + Args: + config (Coqpit): Model configuration. + data (List, optional): Dataset items to infer number of speakers. Defaults to None. 
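+
+        Note:
+            With `use_speaker_embedding=True` a `nn.Embedding` lookup table of size
+            `speaker_embedding_channels` is created, whereas `use_d_vector_file=True` switches the
+            conditioning to externally computed d-vectors of size `d_vector_dim`.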
+ """ + if hasattr(config, "model_args"): + config = config.model_args + self.embedded_speaker_dim = 0 + # init speaker manager + self.speaker_manager = get_speaker_manager(config, data=data) + if config.num_speakers > 0 and self.speaker_manager.num_speakers == 0: + self.speaker_manager.num_speakers = config.num_speakers + self.num_speakers = self.speaker_manager.num_speakers + # init speaker embedding layer + if config.use_speaker_embedding and not config.use_d_vector_file: + self.embedded_speaker_dim = config.speaker_embedding_channels + self.emb_g = nn.Embedding(config.num_speakers, config.speaker_embedding_channels) + # init d-vector usage + if config.use_d_vector_file: + self.embedded_speaker_dim = config.d_vector_dim + + @staticmethod + def _set_cond_input(aux_input: Dict): + """Set the speaker conditioning input based on the multi-speaker mode.""" + sid, g = None, None + if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None: + sid = aux_input["speaker_ids"] + if sid.ndim == 0: + sid = sid.unsqueeze_(0) + if "d_vectors" in aux_input and aux_input["d_vectors"] is not None: + g = aux_input["d_vectors"] + return sid, g + + def forward( + self, + x: torch.tensor, + x_lengths: torch.tensor, + y: torch.tensor, + y_lengths: torch.tensor, + aux_input={"d_vectors": None, "speaker_ids": None}, + ) -> Dict: + """Forward pass of the model. + + Args: + x (torch.tensor): Batch of input character sequence IDs. + x_lengths (torch.tensor): Batch of input character sequence lengths. + y (torch.tensor): Batch of input spectrograms. + y_lengths (torch.tensor): Batch of input spectrogram lengths. + aux_input (dict, optional): Auxiliary inputs for multi-speaker training. Defaults to {"d_vectors": None, "speaker_ids": None}. + + Returns: + Dict: model outputs keyed by the output name. 
+
+        Shapes:
+            - x: :math:`[B, T_seq]`
+            - x_lengths: :math:`[B]`
+            - y: :math:`[B, C, T_spec]`
+            - y_lengths: :math:`[B]`
+            - d_vectors: :math:`[B, C, 1]`
+            - speaker_ids: :math:`[B]`
+        """
+        outputs = {}
+        sid, g = self._set_cond_input(aux_input)
+        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
+
+        # speaker embedding
+        if self.num_speakers > 1 and sid is not None:
+            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
+
+        # posterior encoder
+        z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
+
+        # flow layers
+        z_p = self.flow(z, y_mask, g=g)
+
+        # find the alignment path
+        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
+        with torch.no_grad():
+            o_scale = torch.exp(-2 * logs_p)
+            # logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)  # [b, t, 1]
+            logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
+            logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
+            # logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
+            logp = logp2 + logp3
+            attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
+
+        # duration predictor
+        attn_durations = attn.sum(3)
+        if self.args.use_sdp:
+            nll_duration = self.duration_predictor(
+                x.detach() if self.args.detach_dp_input else x,
+                x_mask,
+                attn_durations,
+                g=g.detach() if self.args.detach_dp_input and g is not None else g,
+            )
+            nll_duration = torch.sum(nll_duration.float() / torch.sum(x_mask))
+            outputs["nll_duration"] = nll_duration
+        else:
+            attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
+            log_durations = self.duration_predictor(
+                x.detach() if self.args.detach_dp_input else x,
+                x_mask,
+                g=g.detach() if self.args.detach_dp_input and g is not None else g,
+            )
+            loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
+            outputs["loss_duration"] = loss_duration
+
+        # expand prior
+        m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
+        logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
+
+        # select a random feature segment for the waveform decoder
+        z_slice, slice_ids = rand_segment(z, y_lengths, self.spec_segment_size)
+        o = self.waveform_decoder(z_slice, g=g)
+        outputs.update(
+            {
+                "model_outputs": o,
+                "alignments": attn.squeeze(1),
+                "slice_ids": slice_ids,
+                "z": z,
+                "z_p": z_p,
+                "m_p": m_p,
+                "logs_p": logs_p,
+                "m_q": m_q,
+                "logs_q": logs_q,
+            }
+        )
+        return outputs
+
+    def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
+        """
+        Shapes:
+            - x: :math:`[B, T_seq]`
+            - d_vectors: :math:`[B, C, 1]`
+            - speaker_ids: :math:`[B]`
+        """
+        sid, g = self._set_cond_input(aux_input)
+        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
+
+        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
+
+        if self.num_speakers > 0 and sid is not None:
+            g = self.emb_g(sid).unsqueeze(-1)
+
+        if self.args.use_sdp:
+            logw = self.duration_predictor(x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp)
+        else:
+            logw = self.duration_predictor(x, x_mask, g=g)
+
+        w = torch.exp(logw) * x_mask * self.length_scale
+        w_ceil = torch.ceil(w)
+        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
+        y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype)
+        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))
+
+        m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
+        logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
+
+        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
+        z = self.flow(z_p, y_mask, g=g, reverse=True)
+        o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)
+
+        outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p}
+        return outputs
+
+    def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
+        """TODO: create an end-point for voice conversion"""
+        assert self.num_speakers > 0, "num_speakers must be larger than 0."
+        g_src = self.emb_g(sid_src).unsqueeze(-1)
+        g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
+        z, _, _, y_mask = self.posterior_encoder(y, y_lengths, g=g_src)
+        z_p = self.flow(z, y_mask, g=g_src)
+        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
+        o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
+        return o_hat, y_mask, (z, z_p, z_hat)
+
+    def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
+        """Perform a single training step. Run the model forward pass and compute losses.
+
+        Args:
+            batch (Dict): Input tensors.
+            criterion (nn.Module): Loss layer designed for the model.
+            optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.
+
+        Returns:
+            Tuple[Dict, Dict]: Model outputs and computed losses.
+        """
+        # pylint: disable=attribute-defined-outside-init
+        if optimizer_idx not in [0, 1]:
+            raise ValueError(" [!] Unexpected `optimizer_idx`.")
+
+        if optimizer_idx == 0:
+            text_input = batch["text_input"]
+            text_lengths = batch["text_lengths"]
+            mel_lengths = batch["mel_lengths"]
+            linear_input = batch["linear_input"]
+            d_vectors = batch["d_vectors"]
+            speaker_ids = batch["speaker_ids"]
+            waveform = batch["waveform"]
+
+            # generator pass
+            outputs = self.forward(
+                text_input,
+                text_lengths,
+                linear_input.transpose(1, 2),
+                mel_lengths,
+                aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
+            )
+
+            # cache tensors for the discriminator
+            self.y_disc_cache = None
+            self.wav_seg_disc_cache = None
+            self.y_disc_cache = outputs["model_outputs"]
+            wav_seg = segment(
+                waveform.transpose(1, 2),
+                outputs["slice_ids"] * self.config.audio.hop_length,
+                self.args.spec_segment_size * self.config.audio.hop_length,
+            )
+            self.wav_seg_disc_cache = wav_seg
+            outputs["waveform_seg"] = wav_seg
+
+            # compute discriminator scores and features
+            outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(outputs["model_outputs"])
+            _, outputs["feats_disc_real"] = self.disc(wav_seg)
+
+            # compute losses
+            with autocast(enabled=False):  # use float32 for the criterion
+                loss_dict = criterion[optimizer_idx](
+                    waveform_hat=outputs["model_outputs"].float(),
+                    waveform=wav_seg.float(),
+                    z_p=outputs["z_p"].float(),
+                    logs_q=outputs["logs_q"].float(),
+                    m_p=outputs["m_p"].float(),
+                    logs_p=outputs["logs_p"].float(),
+                    z_len=mel_lengths,
+                    scores_disc_fake=outputs["scores_disc_fake"],
+                    feats_disc_fake=outputs["feats_disc_fake"],
+                    feats_disc_real=outputs["feats_disc_real"],
+                )
+
+            # handle the duration loss
+            if self.args.use_sdp:
+                loss_dict["nll_duration"] = outputs["nll_duration"]
+                loss_dict["loss"] += outputs["nll_duration"]
+            else:
+                loss_dict["loss_duration"] = outputs["loss_duration"]
+                loss_dict["loss"] += outputs["loss_duration"]
+
+        elif optimizer_idx == 1:
+            # discriminator pass
+            outputs = {}
+
+            # compute scores and features
+            outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(self.y_disc_cache.detach())
+            outputs["scores_disc_real"], outputs["feats_disc_real"] = self.disc(self.wav_seg_disc_cache)
+
+            # compute loss
+            with autocast(enabled=False):  # use float32 for the criterion
+                loss_dict = criterion[optimizer_idx](
+                    outputs["scores_disc_real"],
+                    outputs["scores_disc_fake"],
+                )
+        return outputs, loss_dict
+
+    def train_log(
+        self, ap: AudioProcessor, batch: Dict, outputs: List, name_prefix="train"
+    ):  # pylint: disable=no-self-use
+        """Create visualizations and waveform examples.
+
+        For example, here you can plot spectrograms and generate sample waveforms from these spectrograms to
+        be projected onto Tensorboard.
+
+        Args:
+            ap (AudioProcessor): audio processor used at training.
+            batch (Dict): Model inputs used at the previous training step.
+            outputs (List): Model outputs generated at the previous training step.
+
+        Returns:
+            Tuple[Dict, np.ndarray]: training plots and output waveform.
+        """
+        y_hat = outputs[0]["model_outputs"]
+        y = outputs[0]["waveform_seg"]
+        figures = plot_results(y_hat, y, ap, name_prefix)
+        sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
+        audios = {f"{name_prefix}/audio": sample_voice}
+
+        alignments = outputs[0]["alignments"]
+        align_img = alignments[0].data.cpu().numpy().T
+
+        figures.update(
+            {
+                "alignment": plot_alignment(align_img, output_fig=False),
+            }
+        )
+
+        return figures, audios
+
+    @torch.no_grad()
+    def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
+        return self.train_step(batch, criterion, optimizer_idx)
+
+    def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
+        return self.train_log(ap, batch, outputs, "eval")
+
+    @torch.no_grad()
+    def test_run(self, ap) -> Tuple[Dict, Dict]:
+        """Generic test run for `tts` models used by `Trainer`.
+
+        You can override this for a different behaviour.
+
+        Returns:
+            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
+        """
+        print(" | > Synthesizing test sentences.")
+        test_audios = {}
+        test_figures = {}
+        test_sentences = self.config.test_sentences
+        aux_inputs = self.get_aux_input()
+        for idx, sen in enumerate(test_sentences):
+            wav, alignment, _, _ = synthesis(
+                self,
+                sen,
+                self.config,
+                "cuda" in str(next(self.parameters()).device),
+                ap,
+                speaker_id=aux_inputs["speaker_id"],
+                d_vector=aux_inputs["d_vector"],
+                style_wav=aux_inputs["style_wav"],
+                enable_eos_bos_chars=self.config.enable_eos_bos_chars,
+                use_griffin_lim=True,
+                do_trim_silence=False,
+            ).values()
+
+            test_audios["{}-audio".format(idx)] = wav
+            test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
+        return test_figures, test_audios
+
+    def get_optimizer(self) -> List:
+        """Initiate and return the GAN optimizers based on the config parameters.
+
+        It returns 2 optimizers in a list. The first one is for the generator and the second one is for the discriminator.
+
+        Returns:
+            List: optimizers.
+        """
+        self.disc.requires_grad_(False)
+        gen_parameters = filter(lambda p: p.requires_grad, self.parameters())
+        self.disc.requires_grad_(True)
+        optimizer1 = get_optimizer(
+            self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
+        )
+        optimizer2 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
+        return [optimizer1, optimizer2]
+
+    def get_lr(self) -> List:
+        """Return the initial learning rates for each optimizer.
+
+        Returns:
+            List: learning rates for each optimizer.
+        """
+        return [self.config.lr_gen, self.config.lr_disc]
+
+    def get_scheduler(self, optimizer) -> List:
+        """Return the schedulers for each optimizer.
+
+        Args:
+            optimizer (List[`torch.optim.Optimizer`]): List of optimizers.
+
+        Returns:
+            List: Schedulers, one for each optimizer.
+        """
+        scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
+        scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
+        return [scheduler1, scheduler2]
+
+    def get_criterion(self):
+        """Get criterions for each optimizer. The index in the output list matches the optimizer idx used in
+        `train_step()`"""
+        from TTS.tts.layers.losses import (  # pylint: disable=import-outside-toplevel
+            VitsDiscriminatorLoss,
+            VitsGeneratorLoss,
+        )
+
+        return [VitsGeneratorLoss(self.config), VitsDiscriminatorLoss(self.config)]
+
+    @staticmethod
+    def make_symbols(config):
+        """Create a custom arrangement of symbols used by the model. The output list of symbols propagates through
+        all training and inference steps."""
+        _pad = config.characters["pad"]
+        _punctuations = config.characters["punctuations"]
+        _letters = config.characters["characters"]
+        _letters_ipa = config.characters["phonemes"]
+        symbols = [_pad] + list(_punctuations) + list(_letters)
+        if config.use_phonemes:
+            symbols += list(_letters_ipa)
+        return symbols
+
+    @staticmethod
+    def get_characters(config: Coqpit):
+        if config.characters is not None:
+            symbols = Vits.make_symbols(config)
+        else:
+            from TTS.tts.utils.text.symbols import (  # pylint: disable=import-outside-toplevel
+                parse_symbols,
+                phonemes,
+                symbols,
+            )
+
+            config.characters = parse_symbols()
+        if config.use_phonemes:
+            symbols = phonemes
+        num_chars = len(symbols) + getattr(config, "add_blank", False)
+        return symbols, config, num_chars
+
+    def load_checkpoint(
+        self, config, checkpoint_path, eval=False
+    ):  # pylint: disable=unused-argument, redefined-builtin
+        """Load the model checkpoint and set it up for training or inference."""
+        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        self.load_state_dict(state["model"])
+        if eval:
+            self.eval()
+            assert not self.training
diff --git a/TTS/tts/utils/text/__init__.py b/TTS/tts/utils/text/__init__.py
index 48f69374..d4345b64 100644
--- a/TTS/tts/utils/text/__init__.py
+++ b/TTS/tts/utils/text/__init__.py
@@ -81,7 +81,6 @@ def text2phone(text, language, use_espeak_phonemes=False):
 
         # Fix a few phonemes
         ph = ph.translate(GRUUT_TRANS_TABLE)
-        # print(" > Phonemes: {}".format(ph))
         return ph
 
     raise ValueError(f" [!] Language {language} is not supported for phonemization.")
@@ -116,6 +115,7 @@ def phoneme_to_sequence(
     use_espeak_phonemes: bool = False,
 ) -> List[int]:
     """Converts a string of phonemes to a sequence of IDs.
+    If `custom_symbols` is provided, it will override the default symbols.
Args: text (str): string to convert to a sequence @@ -132,12 +132,11 @@ def phoneme_to_sequence( # pylint: disable=global-statement global _phonemes_to_id, _phonemes - if tp: - _, _phonemes = make_symbols(**tp) - _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)} - elif custom_symbols is not None: + if custom_symbols is not None: _phonemes = custom_symbols - _phonemes_to_id = {s: i for i, s in enumerate(custom_symbols)} + elif tp: + _, _phonemes = make_symbols(**tp) + _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)} sequence = [] clean_text = _clean_text(text, cleaner_names) @@ -155,16 +154,19 @@ def phoneme_to_sequence( return sequence -def sequence_to_phoneme(sequence, tp=None, add_blank=False): +def sequence_to_phoneme(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List["str"] = None): # pylint: disable=global-statement """Converts a sequence of IDs back to a string""" global _id_to_phonemes, _phonemes if add_blank: sequence = list(filter(lambda x: x != len(_phonemes), sequence)) result = "" - if tp: + + if custom_symbols is not None: + _phonemes = custom_symbols + elif tp: _, _phonemes = make_symbols(**tp) - _id_to_phonemes = {i: s for i, s in enumerate(_phonemes)} + _id_to_phonemes = {i: s for i, s in enumerate(_phonemes)} for symbol_id in sequence: if symbol_id in _id_to_phonemes: @@ -177,6 +179,7 @@ def text_to_sequence( text: str, cleaner_names: List[str], custom_symbols: List[str] = None, tp: Dict = None, add_blank: bool = False ) -> List[int]: """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + If `custom_symbols` is provided, it will override the default symbols. Args: text (str): string to convert to a sequence @@ -189,12 +192,12 @@ def text_to_sequence( """ # pylint: disable=global-statement global _symbol_to_id, _symbols - if tp: - _symbols, _ = make_symbols(**tp) - _symbol_to_id = {s: i for i, s in enumerate(_symbols)} - elif custom_symbols is not None: + + if custom_symbols is not None: _symbols = custom_symbols - _symbol_to_id = {s: i for i, s in enumerate(custom_symbols)} + elif tp: + _symbols, _ = make_symbols(**tp) + _symbol_to_id = {s: i for i, s in enumerate(_symbols)} sequence = [] @@ -213,16 +216,18 @@ def text_to_sequence( return sequence -def sequence_to_text(sequence, tp=None, add_blank=False): +def sequence_to_text(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List[str] = None): """Converts a sequence of IDs back to a string""" # pylint: disable=global-statement global _id_to_symbol, _symbols if add_blank: sequence = list(filter(lambda x: x != len(_symbols), sequence)) - if tp: + if custom_symbols is not None: + _symbols = custom_symbols + elif tp: _symbols, _ = make_symbols(**tp) - _id_to_symbol = {i: s for i, s in enumerate(_symbols)} + _id_to_symbol = {i: s for i, s in enumerate(_symbols)} result = "" for symbol_id in sequence: diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py index 1f21369f..40d82365 100644 --- a/TTS/utils/audio.py +++ b/TTS/utils/audio.py @@ -96,10 +96,12 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method ) self.mel_basis = torch.from_numpy(mel_basis).float() - def _amp_to_db(self, x, spec_gain=1.0): + @staticmethod + def _amp_to_db(x, spec_gain=1.0): return torch.log(torch.clamp(x, min=1e-5) * spec_gain) - def _db_to_amp(self, x, spec_gain=1.0): + @staticmethod + def _db_to_amp(x, spec_gain=1.0): return torch.exp(x) / spec_gain diff --git a/TTS/vocoder/models/hifigan_discriminator.py 
b/TTS/vocoder/models/hifigan_discriminator.py index a20c17b4..ca5eaf40 100644 --- a/TTS/vocoder/models/hifigan_discriminator.py +++ b/TTS/vocoder/models/hifigan_discriminator.py @@ -33,10 +33,10 @@ class DiscriminatorP(torch.nn.Module): norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm self.convs = nn.ModuleList( [ - norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), - norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), - norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), - norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), ] ) @@ -81,15 +81,15 @@ class MultiPeriodDiscriminator(torch.nn.Module): Periods are suggested to be prime numbers to reduce the overlap between each discriminator. """ - def __init__(self): + def __init__(self, use_spectral_norm=False): super().__init__() self.discriminators = nn.ModuleList( [ - DiscriminatorP(2), - DiscriminatorP(3), - DiscriminatorP(5), - DiscriminatorP(7), - DiscriminatorP(11), + DiscriminatorP(2, use_spectral_norm=use_spectral_norm), + DiscriminatorP(3, use_spectral_norm=use_spectral_norm), + DiscriminatorP(5, use_spectral_norm=use_spectral_norm), + DiscriminatorP(7, use_spectral_norm=use_spectral_norm), + DiscriminatorP(11, use_spectral_norm=use_spectral_norm), ] ) @@ -99,7 +99,7 @@ class MultiPeriodDiscriminator(torch.nn.Module): x (Tensor): input waveform. Returns: - [List[Tensor]]: list of scores from each discriminator. + [List[Tensor]]: list of scores from each discriminator. [List[List[Tensor]]]: list of list of features from each discriminator's each convolutional layer. Shapes: diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py index 2260b781..a1e16150 100644 --- a/TTS/vocoder/models/hifigan_generator.py +++ b/TTS/vocoder/models/hifigan_generator.py @@ -170,6 +170,10 @@ class HifiganGenerator(torch.nn.Module): upsample_initial_channel, upsample_factors, inference_padding=5, + cond_channels=0, + conv_pre_weight_norm=True, + conv_post_weight_norm=True, + conv_post_bias=True, ): r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF) @@ -218,12 +222,21 @@ class HifiganGenerator(torch.nn.Module): for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): self.resblocks.append(resblock(ch, k, d)) # post convolution layer - self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3)) + self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)) + if cond_channels > 0: + self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1) - def forward(self, x): + if not conv_pre_weight_norm: + remove_weight_norm(self.conv_pre) + + if not conv_post_weight_norm: + remove_weight_norm(self.conv_post) + + def forward(self, x, g=None): """ Args: - x (Tensor): conditioning input tensor. + x (Tensor): feature input tensor. + g (Tensor): global conditioning input tensor. 
         Returns:
             Tensor: output waveform.
@@ -233,6 +246,8 @@ class HifiganGenerator(torch.nn.Module):
             Tensor: [B, 1, T]
         """
         o = self.conv_pre(x)
+        if hasattr(self, "cond_layer"):
+            o = o + self.cond_layer(g)
         for i in range(self.num_upsamples):
             o = F.leaky_relu(o, LRELU_SLOPE)
             o = self.ups[i](o)
diff --git a/TTS/vocoder/models/univnet_generator.py b/TTS/vocoder/models/univnet_generator.py
index 0a6bd4c8..8a66c537 100644
--- a/TTS/vocoder/models/univnet_generator.py
+++ b/TTS/vocoder/models/univnet_generator.py
@@ -1,3 +1,5 @@
+from typing import List
+
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -10,18 +12,35 @@ LRELU_SLOPE = 0.1
 class UnivnetGenerator(torch.nn.Module):
     def __init__(
         self,
-        in_channels,
-        out_channels,
-        hidden_channels,
-        cond_channels,
-        upsample_factors,
-        lvc_layers_each_block,
-        lvc_kernel_size,
-        kpnet_hidden_channels,
-        kpnet_conv_size,
-        dropout,
+        in_channels: int,
+        out_channels: int,
+        hidden_channels: int,
+        cond_channels: int,
+        upsample_factors: List[int],
+        lvc_layers_each_block: int,
+        lvc_kernel_size: int,
+        kpnet_hidden_channels: int,
+        kpnet_conv_size: int,
+        dropout: float,
        use_weight_norm=True,
     ):
+        """Univnet Generator network.
+
+        Paper: https://arxiv.org/pdf/2106.07889.pdf
+
+        Args:
+            in_channels (int): Number of input tensor channels.
+            out_channels (int): Number of channels of the output tensor.
+            hidden_channels (int): Number of hidden network channels.
+            cond_channels (int): Number of channels of the conditioning tensors.
+            upsample_factors (List[int]): List of upsample factors for the upsampling layers.
+            lvc_layers_each_block (int): Number of LVC layers in each block.
+            lvc_kernel_size (int): Kernel size of the LVC layers.
+            kpnet_hidden_channels (int): Number of hidden channels in the kernel predictor network.
+            kpnet_conv_size (int): Number of convolution channels in the kernel predictor network.
+            dropout (float): Dropout rate.
+            use_weight_norm (bool, optional): Enable/disable weight norm. Defaults to True.
+        """
         super().__init__()
 
         self.in_channels = in_channels
diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py
index eeabbea5..63a0af44 100644
--- a/TTS/vocoder/utils/generic_utils.py
+++ b/TTS/vocoder/utils/generic_utils.py
@@ -1,8 +1,11 @@
+from typing import Dict
+
 import numpy as np
 import torch
 from matplotlib import pyplot as plt
 
 from TTS.tts.utils.visual import plot_spectrogram
+from TTS.utils.audio import AudioProcessor
 
 
 def interpolate_vocoder_input(scale_factor, spec):
@@ -26,12 +29,24 @@ def interpolate_vocoder_input(scale_factor, spec):
     return spec
 
 
-def plot_results(y_hat, y, ap, name_prefix):
-    """Plot vocoder model results"""
+def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_prefix: str = None) -> Dict:
+    """Plot the predicted and the real waveform and their spectrograms.
+
+    Args:
+        y_hat (torch.tensor): Predicted waveform.
+        y (torch.tensor): Real waveform.
+        ap (AudioProcessor): Audio processor used to process the waveform.
+        name_prefix (str, optional): Name prefix used to name the figures. Defaults to None.
+
+    Returns:
+        Dict: output figures keyed by the name of the figures.
+    """
+    if name_prefix is None:
+        name_prefix = ""
     # select an instance from batch
-    y_hat = y_hat[0].squeeze(0).detach().cpu().numpy()
-    y = y[0].squeeze(0).detach().cpu().numpy()
+    y_hat = y_hat[0].squeeze().detach().cpu().numpy()
+    y = y[0].squeeze().detach().cpu().numpy()
     spec_fake = ap.melspectrogram(y_hat).T
     spec_real = ap.melspectrogram(y).T
diff --git a/tests/tts_tests/test_vits_train.py b/tests/tts_tests/test_vits_train.py
new file mode 100644
index 00000000..db9d2fc1
--- /dev/null
+++ b/tests/tts_tests/test_vits_train.py
@@ -0,0 +1,54 @@
+import glob
+import os
+import shutil
+
+from tests import get_device_id, get_tests_output_path, run_cli
+from TTS.tts.configs import VitsConfig
+
+config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
+output_path = os.path.join(get_tests_output_path(), "train_outputs")
+
+
+config = VitsConfig(
+    batch_size=2,
+    eval_batch_size=2,
+    num_loader_workers=0,
+    num_eval_loader_workers=0,
+    text_cleaner="english_cleaners",
+    use_phonemes=True,
+    use_espeak_phonemes=True,
+    phoneme_language="en-us",
+    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
+    run_eval=True,
+    test_delay_epochs=-1,
+    epochs=1,
+    print_step=1,
+    print_eval=True,
+    test_sentences=[
+        "Be a voice, not an echo.",
+    ],
+)
+config.audio.do_trim_silence = True
+config.audio.trim_db = 60
+config.save_json(config_path)
+
+# train the model for one epoch
+command_train = (
+    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
+    f"--coqpit.output_path {output_path} "
+    "--coqpit.datasets.0.name ljspeech "
+    "--coqpit.datasets.0.meta_file_train metadata.csv "
+    "--coqpit.datasets.0.meta_file_val metadata.csv "
+    "--coqpit.datasets.0.path tests/data/ljspeech "
+    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
+    "--coqpit.test_delay_epochs 0"
+)
+run_cli(command_train)
+
+# Find latest folder
+continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
+
+# restore the model and continue training for one more epoch
+command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
+run_cli(command_train)
+shutil.rmtree(continue_path)
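
A minimal usage sketch of the new model (illustrative only): it assumes the default VitsConfig character
set, single-speaker mode, and random tensors shaped as in the Vits.forward() docstring. Parameter names
are taken from VitsArgs; no concrete default values are asserted here.

import torch

from TTS.tts.configs import VitsConfig
from TTS.tts.models.vits import Vits

config = VitsConfig()
model = Vits(config)  # builds the text/posterior encoders, flow, duration predictor and HiFiGAN decoder

# random batch of 2 utterances; keep T_spec comfortably above model_args.spec_segment_size for rand_segment()
B, T_seq = 2, 50
T_spec = 4 * config.model_args.spec_segment_size
x = torch.randint(0, config.model_args.num_chars, (B, T_seq))  # num_chars is filled in by Vits.__init__()
x_lengths = torch.tensor([T_seq, T_seq - 10])
y = torch.randn(B, config.model_args.out_channels, T_spec)  # stands in for linear spectrogram frames
y_lengths = torch.tensor([T_spec, T_spec - 16])

outputs = model.forward(x, x_lengths, y, y_lengths)
print(outputs["model_outputs"].shape)  # decoded waveform for the randomly sliced spectrogram segment

with torch.no_grad():
    infer_outputs = model.inference(x[:1])  # durations come from the (stochastic) duration predictor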
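
The changes in TTS/tts/utils/text/__init__.py give `custom_symbols` precedence over `tp` in both conversion
directions. A small round-trip sketch with a toy symbol list (an assumption chosen only for illustration; a
real setup would pass Vits.make_symbols(config) so the model and the text front end agree):

from TTS.tts.utils.text import sequence_to_text, text_to_sequence

# toy symbol list: pad + lowercase letters + a few punctuation marks and space
custom_symbols = ["<PAD>"] + list("abcdefghijklmnopqrstuvwxyz,.!?' ")
seq = text_to_sequence("Be a voice, not an echo.", ["basic_cleaners"], custom_symbols=custom_symbols)
text = sequence_to_text(seq, custom_symbols=custom_symbols)
print(text)  # "be a voice, not an echo." after the cleaner lowercases the input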