From 3b8b105b0d6539ac12972de94e0b2a5077fa1ce2 Mon Sep 17 00:00:00 2001 From: Shivam Mehta Date: Mon, 12 Dec 2022 12:44:15 +0100 Subject: [PATCH] Adding OverFlow (#2183) * Adding encoder * currently modifying hmm * Adding hmm * Adding overflow * Adding overflow setting up flat start * Removing runs * adding normalization parameters * Fixing models on same device * Training overflow and plotting evaluations * Adding inference * At the end of epoch the test sentences are coming on cpu instead of gpu * Adding figures from model during training to monitor * reverting tacotron2 training recipe * fixing inference on gpu for test sentences on config * moving helpers and texts within overflows source code * renaming to overflow * moving loss to the model file * Fixing the rename * Model training but not plotting the test config sentences's audios * Formatting logs * Changing model name to camelcase * Fixing test log * Fixing plotting bug * Adding some tests * Adding more tests to overflow * Adding all tests for overflow * making changes to camel case in config * Adding information about parameters and docstring * removing compute_mel_statistics moved statistic computation to the model instead * Added overflow in readme * Adding more test cases, now it doesn't saves transition_p like tensor and can be dumped as json --- .pre-commit-config.yaml | 2 +- README.md | 1 + TTS/tts/configs/overflow_config.py | 200 +++++++ TTS/tts/configs/shared_configs.py | 2 +- TTS/tts/layers/overflow/__init__.py | 0 TTS/tts/layers/overflow/common_layers.py | 324 ++++++++++++ TTS/tts/layers/overflow/decoder.py | 82 +++ TTS/tts/layers/overflow/neural_hmm.py | 555 ++++++++++++++++++++ TTS/tts/layers/overflow/plotting_utils.py | 74 +++ TTS/tts/models/overflow.py | 393 ++++++++++++++ TTS/tts/utils/visual.py | 7 +- recipes/ljspeech/overflow/lj_parameters.pt | Bin 0 -> 507 bytes recipes/ljspeech/overflow/train_overflow.py | 96 ++++ tests/tts_tests/test_overflow.py | 399 ++++++++++++++ 
tests/tts_tests/test_overflow_train.py | 92 ++++ 15 files changed, 2223 insertions(+), 4 deletions(-) create mode 100644 TTS/tts/configs/overflow_config.py create mode 100644 TTS/tts/layers/overflow/__init__.py create mode 100644 TTS/tts/layers/overflow/common_layers.py create mode 100644 TTS/tts/layers/overflow/decoder.py create mode 100644 TTS/tts/layers/overflow/neural_hmm.py create mode 100644 TTS/tts/layers/overflow/plotting_utils.py create mode 100644 TTS/tts/models/overflow.py create mode 100644 recipes/ljspeech/overflow/lj_parameters.pt create mode 100644 recipes/ljspeech/overflow/train_overflow.py create mode 100644 tests/tts_tests/test_overflow.py create mode 100644 tests/tts_tests/test_overflow_train.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a70572dc..911f2a83 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: 'https://github.com/psf/black' - rev: 20.8b1 + rev: 22.3.0 hooks: - id: black language_version: python3 diff --git a/README.md b/README.md index cb0f7344..3c960af6 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,7 @@ Underlined "TTS*" and "Judy*" are 🐸TTS models - FastSpeech: [paper](https://arxiv.org/abs/1905.09263) - SC-GlowTTS: [paper](https://arxiv.org/abs/2104.05557) - Capacitron: [paper](https://arxiv.org/abs/1906.03402) +- OverFlow: [paper](https://arxiv.org/abs/2211.06892) ### End-to-End Models - VITS: [paper](https://arxiv.org/pdf/2106.06103) diff --git a/TTS/tts/configs/overflow_config.py b/TTS/tts/configs/overflow_config.py new file mode 100644 index 00000000..ffa0a8a0 --- /dev/null +++ b/TTS/tts/configs/overflow_config.py @@ -0,0 +1,200 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig + + +@dataclass +class OverflowConfig(BaseTTSConfig): # The classname has to be camel case + """ + Define parameters for OverFlow model. 
+ + Example: + + >>> from TTS.tts.configs.overflow_config import OverflowConfig + >>> config = OverflowConfig() + + Args: + model (str): + Model name used to select the right model class to initilize. Defaults to `Overflow`. + run_eval_steps (int): + Run evalulation epoch after N steps. If None, waits until training epoch is completed. Defaults to None. + save_step (int): + Save local checkpoint every save_step steps. Defaults to 500. + plot_step (int): + Plot training stats on the logger every plot_step steps. Defaults to 1. + model_param_stats (bool): + Log model parameters stats on the logger dashboard. Defaults to False. + force_generate_statistics (bool): + Force generate mel normalization statistics. Defaults to False. + mel_statistics_parameter_path (str): + Path to the mel normalization statistics.If the model doesn't finds a file there it will generate statistics. + Defaults to None. + num_chars (int): + Number of characters used by the model. It must be defined before initializing the model. Defaults to None. + state_per_phone (int): + Generates N states per phone. Similar, to `add_blank` parameter in GlowTTS but in Overflow it is upsampled by model's encoder. Defaults to 2. + encoder_in_out_features (int): + Channels of encoder input and character embedding tensors. Defaults to 512. + encoder_n_convolutions (int): + Number of convolution layers in the encoder. Defaults to 3. + out_channels (int): + Channels of the final model output. It must match the spectragram size. Defaults to 80. + ar_order (int): + Autoregressive order of the model. Defaults to 1. In ablations of Neural HMM it was found that more autoregression while giving more variation hurts naturalness of the synthesised audio. + sampling_temp (float): + Variation added to the sample from the latent space of neural HMM. Defaults to 0.334. + deterministic_transition (bool): + deterministic duration generation based on duration quantiles as defiend in "S. Ronanki, O. Watts, S. King, and G. E. 
Henter, “Medianbased generation of synthetic speech durations using a nonparametric approach,” in Proc. SLT, 2016.". Defaults to True. + duration_threshold (float): + Threshold for duration quantiles. Defaults to 0.55. Tune this to change the speaking rate of the synthesis, where lower values defines a slower speaking rate and higher values defines a faster speaking rate. + use_grad_checkpointing (bool): + Use gradient checkpointing to save memory. In a multi-GPU setting currently pytorch does not supports gradient checkpoint inside a loop so we will have to turn it off then.Adjust depending on whatever get more batch size either by using a single GPU or multi-GPU. Defaults to True. + max_sampling_time (int): + Maximum sampling time while synthesising latents from neural HMM. Defaults to 1000. + prenet_type (str): + `original` or `bn`. `original` sets the default Prenet and `bn` uses Batch Normalization version of the + Prenet. Defaults to `original`. + prenet_dim (int): + Dimension of the Prenet. Defaults to 256. + prenet_n_layers (int): + Number of layers in the Prenet. Defaults to 2. + prenet_dropout (float): + Dropout rate of the Prenet. Defaults to 0.5. + prenet_dropout_at_inference (bool): + Use dropout at inference time. Defaults to False. + memory_rnn_dim (int): + Dimension of the memory LSTM to process the prenet output. Defaults to 1024. + outputnet_size (list[int]): + Size of the output network inside the neural HMM. Defaults to [1024]. + flat_start_params (dict): + Parameters for the flat start initialization of the neural HMM. Defaults to `{"mean": 0.0, "std": 1.0, "transition_p": 0.14}`. + It will be recomputed when you pass the dataset. + std_floor (float): + Floor value for the standard deviation of the neural HMM. Prevents model cheating by putting point mass and getting infinite likelihood at any datapoint. Defaults to 0.01. + It is called `variance flooring` in standard HMM literature. 
+ hidden_channels_dec (int): + Number of base hidden channels used by the decoder WaveNet network. Defaults to 150. + kernel_size_dec (int): + Decoder kernel size. Defaults to 5 + dilation_rate (int): + Rate to increase dilation by each layer in a decoder block. Defaults to 1. + num_flow_blocks_dec (int): + Number of decoder layers in each decoder block. Defaults to 4. + dropout_p_dec (float): + Dropout rate of the decoder. Defaults to 0.05. + num_splits (int): + Number of split levels in inversible conv1x1 operation. Defaults to 4. + num_squeeze (int): + Number of squeeze levels. When squeezing channels increases and time steps reduces by the factor + 'num_squeeze'. Defaults to 2. + sigmoid_scale (bool): + enable/disable sigmoid scaling in decoder. Defaults to False. + c_in_channels (int): + Unused parameter from GlowTTS's decoder. Defaults to 0. + optimizer (str): + Optimizer to use for training. Defaults to `adam`. + optimizer_params (dict): + Parameters for the optimizer. Defaults to `{"weight_decay": 1e-6}`. + grad_clip (float): + Gradient clipping threshold. Defaults to 40_000. + lr (float): + Learning rate. Defaults to 1e-3. + lr_scheduler (str): + Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or + `TTS.utils.training`. Defaults to `None`. + min_seq_len (int): + Minimum input sequence length to be used at training. + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. 
+ """ + + model: str = "Overflow" + + # Training and Checkpoint configs + run_eval_steps: int = 100 + save_step: int = 500 + plot_step: int = 1 + model_param_stats: bool = False + + # data parameters + force_generate_statistics: bool = False + mel_statistics_parameter_path: str = None + + # Encoder parameters + num_chars: int = None + state_per_phone: int = 2 + encoder_in_out_features: int = 512 + encoder_n_convolutions: int = 3 + + # HMM parameters + out_channels: int = 80 + ar_order: int = 1 + sampling_temp: float = 0.334 + deterministic_transition: bool = True + duration_threshold: float = 0.55 + use_grad_checkpointing: bool = True + max_sampling_time: int = 1000 + + ## Prenet parameters + prenet_type: str = "original" + prenet_dim: int = 256 + prenet_n_layers: int = 2 + prenet_dropout: float = 0.5 + prenet_dropout_at_inference: bool = False + memory_rnn_dim: int = 1024 + + ## Outputnet parameters + outputnet_size: List[int] = field(default_factory=lambda: [1024]) + flat_start_params: dict = field(default_factory=lambda: {"mean": 0.0, "std": 1.0, "transition_p": 0.14}) + std_floor: float = 0.01 + + # Decoder parameters + hidden_channels_dec: int = 150 + kernel_size_dec: int = 5 + dilation_rate: int = 1 + num_flow_blocks_dec: int = 12 + num_block_layers: int = 4 + dropout_p_dec: float = 0.05 + num_splits: int = 4 + num_squeeze: int = 2 + sigmoid_scale: bool = False + c_in_channels: int = 0 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"weight_decay": 1e-6}) + grad_clip: float = 40000.0 + lr: float = 1e-3 + lr_scheduler: str = None + + # overrides + min_seq_len: int = 3 + max_seq_len: int = 500 + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "Be a voice, not an echo.", + ] + ) + + # Extra needed config + r: int = 1 + use_d_vector_file: bool = False + use_speaker_embedding: bool = False + + def check_values(self): + """Validate the hyperparameters. 
+ + Raises: + AssertionError: when the parameters network is not defined + AssertionError: transition probability is not between 0 and 1 + """ + assert self.ar_order > 0, "AR order must be greater than 0 it is an autoregressive model." + assert ( + len(self.outputnet_size) >= 1 + ), f"Parameter Network must have atleast one layer check the config file for parameter network. Provided: {self.parameternetwork}" + assert ( + 0 < self.flat_start_params["transition_p"] < 1 + ), f"Transition probability must be between 0 and 1. Provided: {self.flat_start_params['transition_p']}" diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py index e1ea8be3..676ea4a9 100644 --- a/TTS/tts/configs/shared_configs.py +++ b/TTS/tts/configs/shared_configs.py @@ -315,7 +315,7 @@ class BaseTTSConfig(BaseTrainingConfig): optimizer: str = "radam" optimizer_params: dict = None # scheduler - lr_scheduler: str = "" + lr_scheduler: str = None lr_scheduler_params: dict = field(default_factory=lambda: {}) # testing test_sentences: List[str] = field(default_factory=lambda: []) diff --git a/TTS/tts/layers/overflow/__init__.py b/TTS/tts/layers/overflow/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/TTS/tts/layers/overflow/common_layers.py b/TTS/tts/layers/overflow/common_layers.py new file mode 100644 index 00000000..ba9a139e --- /dev/null +++ b/TTS/tts/layers/overflow/common_layers.py @@ -0,0 +1,324 @@ +from typing import List, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from tqdm.auto import tqdm + +from TTS.tts.layers.tacotron.common_layers import Linear +from TTS.tts.layers.tacotron.tacotron2 import ConvBNBlock + + +class Encoder(nn.Module): + r"""Neural HMM Encoder + + Same as Tacotron 2 encoder but increases the input length by states per phone + + Args: + num_chars (int): Number of characters in the input. + state_per_phone (int): Number of states per phone. 
+ in_out_channels (int): number of input and output channels. + n_convolutions (int): number of convolutional layers. + """ + + def __init__(self, num_chars, state_per_phone, in_out_channels=512, n_convolutions=3): + + super().__init__() + + self.state_per_phone = state_per_phone + self.in_out_channels = in_out_channels + + self.emb = nn.Embedding(num_chars, in_out_channels) + self.convolutions = nn.ModuleList() + for _ in range(n_convolutions): + self.convolutions.append(ConvBNBlock(in_out_channels, in_out_channels, 5, "relu")) + self.lstm = nn.LSTM( + in_out_channels, + int(in_out_channels / 2) * state_per_phone, + num_layers=1, + batch_first=True, + bias=True, + bidirectional=True, + ) + self.rnn_state = None + + def forward(self, x: torch.FloatTensor, x_len: torch.LongTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]: + """Forward pass to the encoder. + + Args: + x (torch.FloatTensor): input text indices. + - shape: :math:`(b, T_{in})` + x_len (torch.LongTensor): input text lengths. + - shape: :math:`(b,)` + + Returns: + Tuple[torch.FloatTensor, torch.LongTensor]: encoder outputs and output lengths. + -shape: :math:`((b, T_{in} * states_per_phone, in_out_channels), (b,))` + """ + b, T = x.shape + o = self.emb(x).transpose(1, 2) + for layer in self.convolutions: + o = layer(o) + o = o.transpose(1, 2) + o = nn.utils.rnn.pack_padded_sequence(o, x_len.cpu(), batch_first=True) + self.lstm.flatten_parameters() + o, _ = self.lstm(o) + o, _ = nn.utils.rnn.pad_packed_sequence(o, batch_first=True) + o = o.reshape(b, T * self.state_per_phone, self.in_out_channels) + x_len = x_len * self.state_per_phone + return o, x_len + + def inference(self, x, x_len): + """Inference to the encoder. + + Args: + x (torch.FloatTensor): input text indices. + - shape: :math:`(b, T_{in})` + x_len (torch.LongTensor): input text lengths. + - shape: :math:`(b,)` + + Returns: + Tuple[torch.FloatTensor, torch.LongTensor]: encoder outputs and output lengths. 
+ -shape: :math:`((b, T_{in} * states_per_phone, in_out_channels), (b,))` + """ + b, T = x.shape + o = self.emb(x).transpose(1, 2) + for layer in self.convolutions: + o = layer(o) + o = o.transpose(1, 2) + # self.lstm.flatten_parameters() + o, _ = self.lstm(o) + o = o.reshape(b, T * self.state_per_phone, self.in_out_channels) + x_len = x_len * self.state_per_phone + return o, x_len + + +class ParameterModel(nn.Module): + r"""Main neural network of the outputnet + + Note: Do not put dropout layers here, the model will not converge. + + Args: + outputnet_size (List[int]): the architecture of the parameter model + input_size (int): size of input for the first layer + output_size (int): size of output i.e size of the feature dim + frame_channels (int): feature dim to set the flat start bias + flat_start_params (dict): flat start parameters to set the bias + """ + + def __init__( + self, + outputnet_size: List[int], + input_size: int, + output_size: int, + frame_channels: int, + flat_start_params: dict, + ): + super().__init__() + self.frame_channels = frame_channels + + self.layers = nn.ModuleList( + [Linear(inp, out) for inp, out in zip([input_size] + outputnet_size[:-1], outputnet_size)] + ) + self.last_layer = nn.Linear(outputnet_size[-1], output_size) + self.flat_start_output_layer( + flat_start_params["mean"], flat_start_params["std"], flat_start_params["transition_p"] + ) + + def flat_start_output_layer(self, mean, std, transition_p): + self.last_layer.weight.data.zero_() + self.last_layer.bias.data[0 : self.frame_channels] = mean + self.last_layer.bias.data[self.frame_channels : 2 * self.frame_channels] = OverflowUtils.inverse_softplus(std) + self.last_layer.bias.data[2 * self.frame_channels :] = OverflowUtils.inverse_sigmod(transition_p) + + def forward(self, x): + for layer in self.layers: + x = F.relu(layer(x)) + x = self.last_layer(x) + return x + + +class Outputnet(nn.Module): + r""" + This network takes current state and previous observed values as input + 
and returns its parameters, mean, standard deviation and probability + of transition to the next state + """ + + def __init__( + self, + encoder_dim: int, + memory_rnn_dim: int, + frame_channels: int, + outputnet_size: List[int], + flat_start_params: dict, + std_floor: float = 1e-2, + ): + super().__init__() + + self.frame_channels = frame_channels + self.flat_start_params = flat_start_params + self.std_floor = std_floor + + input_size = memory_rnn_dim + encoder_dim + output_size = 2 * frame_channels + 1 + + self.parametermodel = ParameterModel( + outputnet_size=outputnet_size, + input_size=input_size, + output_size=output_size, + flat_start_params=flat_start_params, + frame_channels=frame_channels, + ) + + def forward(self, ar_mels, inputs): + r"""Inputs observation and returns the means, stds and transition probability for the current state + + Args: + ar_mel_inputs (torch.FloatTensor): shape (batch, prenet_dim) + states (torch.FloatTensor): (batch, hidden_states, hidden_state_dim) + + Returns: + means: means for the emission observation for each feature + - shape: (B, hidden_states, feature_size) + stds: standard deviations for the emission observation for each feature + - shape: (batch, hidden_states, feature_size) + transition_vectors: transition vector for the current hidden state + - shape: (batch, hidden_states) + """ + batch_size, prenet_dim = ar_mels.shape[0], ar_mels.shape[1] + N = inputs.shape[1] + + ar_mels = ar_mels.unsqueeze(1).expand(batch_size, N, prenet_dim) + ar_mels = torch.cat((ar_mels, inputs), dim=2) + ar_mels = self.parametermodel(ar_mels) + + mean, std, transition_vector = ( + ar_mels[:, :, 0 : self.frame_channels], + ar_mels[:, :, self.frame_channels : 2 * self.frame_channels], + ar_mels[:, :, 2 * self.frame_channels :].squeeze(2), + ) + std = F.softplus(std) + std = self._floor_std(std) + return mean, std, transition_vector + + def _floor_std(self, std): + r""" + It clamps the standard deviation to not to go below some level + This 
removes the problem when the model tries to cheat for higher likelihoods by converting + one of the gaussians to a point mass. + + Args: + std (float Tensor): tensor containing the standard deviation to be + """ + original_tensor = std.clone().detach() + std = torch.clamp(std, min=self.std_floor) + if torch.any(original_tensor != std): + print( + "[*] Standard deviation was floored! The model is preventing overfitting, nothing serious to worry about" + ) + return std + + +class OverflowUtils: + @staticmethod + def get_data_parameters_for_flat_start( + data_loader: torch.utils.data.DataLoader, out_channels: int, states_per_phone: int + ): + """Generates data parameters for flat starting the HMM. + + Args: + data_loader (torch.utils.data.Dataloader): _description_ + out_channels (int): mel spectrogram channels + states_per_phone (_type_): HMM states per phone + """ + + # State related information for transition_p + total_state_len = 0 + total_mel_len = 0 + + # Useful for data mean an std + total_mel_sum = 0 + total_mel_sq_sum = 0 + + for batch in tqdm(data_loader, leave=False): + text_lengths = batch["token_id_lengths"] + mels = batch["mel"] + mel_lengths = batch["mel_lengths"] + + total_state_len += torch.sum(text_lengths) + total_mel_len += torch.sum(mel_lengths) + total_mel_sum += torch.sum(mels) + total_mel_sq_sum += torch.sum(torch.pow(mels, 2)) + + data_mean = total_mel_sum / (total_mel_len * out_channels) + data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2)) + average_num_states = total_state_len / len(data_loader.dataset) + average_mel_len = total_mel_len / len(data_loader.dataset) + average_duration_each_state = average_mel_len / average_num_states + init_transition_prob = 1 / average_duration_each_state + + return data_mean, data_std, (init_transition_prob * states_per_phone) + + @staticmethod + @torch.no_grad() + def update_flat_start_transition(model, transition_p): + 
model.neural_hmm.output_net.parametermodel.flat_start_output_layer(0.0, 1.0, transition_p) + + @staticmethod + def log_clamped(x, eps=1e-04): + """ + Avoids the log(0) problem + + Args: + x (torch.tensor): input tensor + eps (float, optional): lower bound. Defaults to 1e-04. + + Returns: + torch.tensor: :math:`log(x)` + """ + clamped_x = torch.clamp(x, min=eps) + return torch.log(clamped_x) + + @staticmethod + def inverse_sigmod(x): + r""" + Inverse of the sigmoid function + """ + if not torch.is_tensor(x): + x = torch.tensor(x) + return OverflowUtils.log_clamped(x / (1.0 - x)) + + @staticmethod + def inverse_softplus(x): + r""" + Inverse of the softplus function + """ + if not torch.is_tensor(x): + x = torch.tensor(x) + return OverflowUtils.log_clamped(torch.exp(x) - 1.0) + + @staticmethod + def logsumexp(x, dim): + r""" + Differentiable LogSumExp: Does not creates nan gradients + when all the inputs are -inf yeilds 0 gradients. + Args: + x : torch.Tensor - The input tensor + dim: int - The dimension on which the log sum exp has to be applied + """ + + m, _ = x.max(dim=dim) + mask = m == -float("inf") + s = (x - m.masked_fill_(mask, 0).unsqueeze(dim=dim)).exp().sum(dim=dim) + return s.masked_fill_(mask, 1).log() + m.masked_fill_(mask, -float("inf")) + + @staticmethod + def double_pad(list_of_different_shape_tensors): + r""" + Pads the list of tensors in 2 dimensions + """ + second_dim_lens = [len(a) for a in [i[0] for i in list_of_different_shape_tensors]] + second_dim_max = max(second_dim_lens) + padded_x = [F.pad(x, (0, second_dim_max - len(x[0]))) for x in list_of_different_shape_tensors] + return nn.utils.rnn.pad_sequence(padded_x, batch_first=True) diff --git a/TTS/tts/layers/overflow/decoder.py b/TTS/tts/layers/overflow/decoder.py new file mode 100644 index 00000000..4e65993f --- /dev/null +++ b/TTS/tts/layers/overflow/decoder.py @@ -0,0 +1,82 @@ +import torch +from torch import nn + +from TTS.tts.layers.glow_tts.decoder import Decoder as GlowDecoder +from 
TTS.tts.utils.helpers import sequence_mask + + +class Decoder(nn.Module): + """Uses glow decoder with some modifications. + :: + + Squeeze -> ActNorm -> InvertibleConv1x1 -> AffineCoupling -> Unsqueeze + + Args: + in_channels (int): channels of input tensor. + hidden_channels (int): hidden decoder channels. + kernel_size (int): Coupling block kernel size. (Wavenet filter kernel size.) + dilation_rate (int): rate to increase dilation by each layer in a decoder block. + num_flow_blocks (int): number of decoder blocks. + num_coupling_layers (int): number coupling layers. (number of wavenet layers.) + dropout_p (float): wavenet dropout rate. + sigmoid_scale (bool): enable/disable sigmoid scaling in coupling layer. + """ + + def __init__( + self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_flow_blocks, + num_coupling_layers, + dropout_p=0.0, + num_splits=4, + num_squeeze=2, + sigmoid_scale=False, + c_in_channels=0, + ): + + super().__init__() + + self.glow_decoder = GlowDecoder( + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_flow_blocks, + num_coupling_layers, + dropout_p, + num_splits, + num_squeeze, + sigmoid_scale, + c_in_channels, + ) + self.n_sqz = num_squeeze + + def forward(self, x, x_len, g=None, reverse=False): + """ + Input shapes: + - x: :math:`[B, C, T]` + - x_len :math:`[B]` + - g: :math:`[B, C]` + + Output shapes: + - x: :math:`[B, C, T]` + - x_len :math:`[B]` + - logget_tot :math:`[B]` + """ + x, x_len, x_max_len = self.preprocess(x, x_len, x_len.max()) + x_mask = torch.unsqueeze(sequence_mask(x_len, x_max_len), 1).to(x.dtype) + x, logdet_tot = self.glow_decoder(x, x_mask, g, reverse) + return x, x_len, logdet_tot + + def preprocess(self, y, y_lengths, y_max_length): + if y_max_length is not None: + y_max_length = torch.div(y_max_length, self.n_sqz, rounding_mode="floor") * self.n_sqz + y = y[:, :, :y_max_length] + y_lengths = torch.div(y_lengths, self.n_sqz, rounding_mode="floor") * self.n_sqz + return 
y, y_lengths, y_max_length + + def store_inverse(self): + self.glow_decoder.store_inverse() diff --git a/TTS/tts/layers/overflow/neural_hmm.py b/TTS/tts/layers/overflow/neural_hmm.py new file mode 100644 index 00000000..3aa59a0b --- /dev/null +++ b/TTS/tts/layers/overflow/neural_hmm.py @@ -0,0 +1,555 @@ +from typing import List + +import torch +import torch.distributions as tdist +import torch.nn.functional as F +from torch import nn +from torch.utils.checkpoint import checkpoint + +from TTS.tts.layers.overflow.common_layers import Outputnet, OverflowUtils +from TTS.tts.layers.tacotron.common_layers import Prenet +from TTS.tts.utils.helpers import sequence_mask + + +class NeuralHMM(nn.Module): + """Autoregressive left to right HMM model primarily used in "Neural HMMs are all you need (for high-quality attention-free TTS)" + + Paper:: + https://arxiv.org/abs/2108.13320 + + Paper abstract:: + Neural sequence-to-sequence TTS has achieved significantly better output quality than statistical speech synthesis using + HMMs. However, neural TTS is generally not probabilistic and uses non-monotonic attention. Attention failures increase + training time and can make synthesis babble incoherently. This paper describes how the old and new paradigms can be + combined to obtain the advantages of both worlds, by replacing attention in neural TTS with an autoregressive left-right + no-skip hidden Markov model defined by a neural network. Based on this proposal, we modify Tacotron 2 to obtain an + HMM-based neural TTS model with monotonic alignment, trained to maximise the full sequence likelihood without + approximation. We also describe how to combine ideas from classical and contemporary TTS for best results. The resulting + example system is smaller and simpler than Tacotron 2, and learns to speak with fewer iterations and less data, whilst + achieving comparable naturalness prior to the post-net. Our approach also allows easy control over speaking rate. 
+ + Args: + frame_channels (int): Output dimension to generate. + ar_order (int): Autoregressive order of the model. In ablations of Neural HMM it was found that more autoregression while giving more variation hurts naturalness of the synthesised audio. + deterministic_transition (bool): deterministic duration generation based on duration quantiles as defiend in "S. Ronanki, O. Watts, S. King, and G. E. Henter, “Medianbased generation of synthetic speech durations using a nonparametric approach,” in Proc. SLT, 2016.". Defaults to True. + encoder_dim (int): Channels of encoder input and character embedding tensors. Defaults to 512. + prenet_type (str): `original` or `bn`. `original` sets the default Prenet and `bn` uses Batch Normalization version of the Prenet. + prenet_dim (int): Dimension of the Prenet. + prenet_n_layers (int): Number of layers in the Prenet. + prenet_dropout (float): Dropout probability of the Prenet. + prenet_dropout_at_inference (bool): If True, dropout is applied at inference time. + memory_rnn_dim (int): Size of the memory RNN to process output of prenet. + outputnet_size (List[int]): Size of the output network inside the neural HMM. + flat_start_params (dict): Parameters for the flat start initialization of the neural HMM. + std_floor (float): Floor value for the standard deviation of the neural HMM. Prevents model cheating by putting point mass and getting infinite likelihood at any datapoint. + use_grad_checkpointing (bool, optional): Use gradient checkpointing to save memory. Defaults to True. 
+ """ + + def __init__( + self, + frame_channels: int, + ar_order: int, + deterministic_transition: bool, + encoder_dim: int, + prenet_type: str, + prenet_dim: int, + prenet_n_layers: int, + prenet_dropout: float, + prenet_dropout_at_inference: bool, + memory_rnn_dim: int, + outputnet_size: List[int], + flat_start_params: dict, + std_floor: float, + use_grad_checkpointing: bool = True, + ): + super().__init__() + + self.frame_channels = frame_channels + self.ar_order = ar_order + self.deterministic_transition = deterministic_transition + self.prenet_dim = prenet_dim + self.memory_rnn_dim = memory_rnn_dim + self.use_grad_checkpointing = use_grad_checkpointing + + self.transition_model = TransitionModel() + self.emission_model = EmissionModel() + + assert ar_order > 0, f"AR order must be greater than 0 provided {ar_order}" + + self.ar_order = ar_order + self.prenet = Prenet( + in_features=frame_channels * ar_order, + prenet_type=prenet_type, + prenet_dropout=prenet_dropout, + dropout_at_inference=prenet_dropout_at_inference, + out_features=[self.prenet_dim for _ in range(prenet_n_layers)], + bias=False, + ) + self.memory_rnn = nn.LSTMCell(input_size=prenet_dim, hidden_size=memory_rnn_dim) + self.output_net = Outputnet( + encoder_dim, memory_rnn_dim, frame_channels, outputnet_size, flat_start_params, std_floor + ) + self.register_buffer("go_tokens", torch.zeros(ar_order, 1)) + + def forward(self, inputs, inputs_len, mels, mel_lens): + r"""HMM forward algorithm for training uses logarithmic version of Rabiner (1989) forward algorithm. 
def forward(self, inputs, inputs_len, mels, mel_lens):
    r"""Run the scaled log-domain HMM forward algorithm (Rabiner, 1989) for training.

    Args:
        inputs (torch.FloatTensor): encoder outputs.
        inputs_len (torch.LongTensor): encoder output lengths.
        mels (torch.FloatTensor): mel inputs.
        mel_lens (torch.LongTensor): length of mel inputs.

    Shapes:
        - inputs: (B, T, D_out_enc)
        - inputs_len: (B)
        - mels: (B, D_mel, T_mel)
        - mel_lens: (B)

    Returns:
        log_prob (torch.FloatTensor): log probability of the sequence,
        plus the scaled alphas, transition matrix and per-step means.
    """
    batch_size, n_states, _ = inputs.shape
    max_frames = torch.max(mel_lens)
    mels = mels.permute(0, 2, 1)  # -> (B, T_mel, D_mel)

    # Forward-algorithm bookkeeping: log pi, scaled alphas and scaling constants.
    log_state_priors = self._initialize_log_state_priors(inputs)
    log_c, log_alpha_scaled, transition_matrix, means = self._initialize_forward_algorithm_variables(
        mels, n_states
    )

    # Autoregressive conditioning: go-token shifted mels plus LSTM memory state.
    ar_inputs = self._add_go_token(mels)
    h_memory, c_memory = self._init_lstm_states(batch_size, self.memory_rnn_dim, mels)

    for t in range(max_frames):
        h_memory, c_memory = self._process_ar_timestep(t, ar_inputs, h_memory, c_memory)
        # Note: gradient checkpointing currently doesn't work with multiple GPUs inside a loop.
        if self.use_grad_checkpointing and self.training:
            mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs)
        else:
            mean, std, transition_vector = self.output_net(h_memory, inputs)

        emission_log_prob = self.emission_model(mels[:, t], mean, std, inputs_len)
        if t == 0:
            # First frame: only the state priors contribute, no transition term.
            alpha_unscaled = log_state_priors + emission_log_prob
        else:
            alpha_unscaled = emission_log_prob + self.transition_model(
                log_alpha_scaled[:, t - 1, :], transition_vector, inputs_len
            )
        log_c[:, t] = torch.logsumexp(alpha_unscaled, dim=1)
        log_alpha_scaled[:, t, :] = alpha_unscaled - log_c[:, t].unsqueeze(1)
        transition_matrix[:, t] = transition_vector  # needed for absorption state calculation
        means.append(mean.detach())  # kept only for plotting

    log_c, log_alpha_scaled = self._mask_lengths(mel_lens, log_c, log_alpha_scaled)

    sum_final_log_c = self.get_absorption_state_scaling_factor(
        mel_lens, log_alpha_scaled, inputs_len, transition_matrix
    )

    log_probs = torch.sum(log_c, dim=1) + sum_final_log_c
    return log_probs, log_alpha_scaled, transition_matrix, means


@staticmethod
def _mask_lengths(mel_lens, log_c, log_alpha_scaled):
    """Zero out forward variables past each sequence's length so padding
    does not contribute to the loss.

    Args:
        mel_lens (torch.IntTensor): (batch)
        log_c (torch.FloatTensor): (batch, T)
        log_alpha_scaled (torch.FloatTensor): (batch, T, N)

    Returns:
        log_c, log_alpha_scaled with padded positions masked to zero.
    """
    time_mask = sequence_mask(mel_lens)
    log_c = log_c * time_mask
    log_alpha_scaled = log_alpha_scaled * time_mask.unsqueeze(2)
    return log_c, log_alpha_scaled
Run the prenet output through the post prenet rnn + + Args: + t (int): mel-spec timestep + ar_inputs (torch.FloatTensor): go-token appended mel-spectrograms + - shape: (b, D_out, T_out) + h_post_prenet (torch.FloatTensor): previous timestep rnn hidden state + - shape: (b, memory_rnn_dim) + c_post_prenet (torch.FloatTensor): previous timestep rnn cell state + - shape: (b, memory_rnn_dim) + + Returns: + h_post_prenet (torch.FloatTensor): rnn hidden state of the current timestep + c_post_prenet (torch.FloatTensor): rnn cell state of the current timestep + """ + prenet_input = ar_inputs[:, t : t + self.ar_order].flatten(1) + memory_inputs = self.prenet(prenet_input) + h_memory, c_memory = self.memory_rnn(memory_inputs, (h_memory, c_memory)) + return h_memory, c_memory + + def _add_go_token(self, mel_inputs): + """Append the go token to create the autoregressive input + Args: + mel_inputs (torch.FloatTensor): (batch_size, T, n_mel_channel) + Returns: + ar_inputs (torch.FloatTensor): (batch_size, T, n_mel_channel) + """ + batch_size, T, _ = mel_inputs.shape + go_tokens = self.go_tokens.unsqueeze(0).expand(batch_size, self.ar_order, self.frame_channels) + ar_inputs = torch.cat((go_tokens, mel_inputs), dim=1)[:, :T] + return ar_inputs + + @staticmethod + def _initialize_forward_algorithm_variables(mel_inputs, N): + r"""Initialize placeholders for forward algorithm variables, to use a stable + version we will use log_alpha_scaled and the scaling constant + + Args: + mel_inputs (torch.FloatTensor): (b, T_max, frame_channels) + N (int): number of states + Returns: + log_c (torch.FloatTensor): Scaling constant (b, T_max) + """ + b, T_max, _ = mel_inputs.shape + log_alpha_scaled = mel_inputs.new_zeros((b, T_max, N)) + log_c = mel_inputs.new_zeros(b, T_max) + transition_matrix = mel_inputs.new_zeros((b, T_max, N)) + + # Saving for plotting later, will not have gradient tapes + means = [] + return log_c, log_alpha_scaled, transition_matrix, means + + @staticmethod + def 
_init_lstm_states(batch_size, hidden_state_dim, device_tensor): + r""" + Initialize Hidden and Cell states for LSTM Cell + + Args: + batch_size (Int): batch size + hidden_state_dim (Int): dimensions of the h and c + device_tensor (torch.FloatTensor): useful for the device and type + + Returns: + (torch.FloatTensor): shape (batch_size, hidden_state_dim) + can be hidden state for LSTM + (torch.FloatTensor): shape (batch_size, hidden_state_dim) + can be the cell state for LSTM + """ + return ( + device_tensor.new_zeros(batch_size, hidden_state_dim), + device_tensor.new_zeros(batch_size, hidden_state_dim), + ) + + def get_absorption_state_scaling_factor(self, mels_len, log_alpha_scaled, inputs_len, transition_vector): + """Returns the final scaling factor of absorption state + + Args: + mels_len (torch.IntTensor): Input size of mels to + get the last timestep of log_alpha_scaled + log_alpha_scaled (torch.FloatTEnsor): State probabilities + text_lengths (torch.IntTensor): length of the states to + mask the values of states lengths + ( + Useful when the batch has very different lengths, + when the length of an observation is less than + the number of max states, then the log alpha after + the state value is filled with -infs. 
So we mask + those values so that it only consider the states + which are needed for that length + ) + transition_vector (torch.FloatTensor): transtiion vector for each state per timestep + + Shapes: + - mels_len: (batch_size) + - log_alpha_scaled: (batch_size, N, T) + - text_lengths: (batch_size) + - transition_vector: (batch_size, N, T) + + Returns: + sum_final_log_c (torch.FloatTensor): (batch_size) + + """ + N = torch.max(inputs_len) + max_inputs_len = log_alpha_scaled.shape[2] + state_lengths_mask = sequence_mask(inputs_len, max_len=max_inputs_len) + + last_log_alpha_scaled_index = ( + (mels_len - 1).unsqueeze(-1).expand(-1, N).unsqueeze(1) + ) # Batch X Hidden State Size + last_log_alpha_scaled = torch.gather(log_alpha_scaled, 1, last_log_alpha_scaled_index).squeeze(1) + last_log_alpha_scaled = last_log_alpha_scaled.masked_fill(~state_lengths_mask, -float("inf")) + + last_transition_vector = torch.gather(transition_vector, 1, last_log_alpha_scaled_index).squeeze(1) + last_transition_probability = torch.sigmoid(last_transition_vector) + log_probability_of_transitioning = OverflowUtils.log_clamped(last_transition_probability) + + last_transition_probability_index = self.get_mask_for_last_item(inputs_len, inputs_len.device) + log_probability_of_transitioning = log_probability_of_transitioning.masked_fill( + ~last_transition_probability_index, -float("inf") + ) + final_log_c = last_log_alpha_scaled + log_probability_of_transitioning + + # If the length of the mel is less than the number of states it will select the -inf values leading to nan gradients + # Ideally, we should clean the dataset otherwise this is a little hack uncomment the line below + # final_log_c = final_log_c.clamp(min=torch.finfo(final_log_c.dtype).min) + + sum_final_log_c = torch.logsumexp(final_log_c, dim=1) + return sum_final_log_c + + @staticmethod + def get_mask_for_last_item(lengths, device, out_tensor=None): + """Returns n-1 mask for the last item in the sequence. 
@staticmethod
def get_mask_for_last_item(lengths, device, out_tensor=None):
    """Return a mask that is True only at index (length - 1) of each row.

    Args:
        lengths (torch.IntTensor): lengths in a batch.
        device (str): target device for the index tensor.
        out_tensor (torch.Tensor, optional): reuse the memory of this tensor.
            Defaults to None.

    Returns:
        Boolean mask of shape :math:`(b, max_len)`.
    """
    max_len = torch.max(lengths).item()
    if out_tensor is None:
        ids = torch.arange(0, max_len, device=device)
    else:
        ids = torch.arange(0, max_len, out=out_tensor)
    return ids == lengths.unsqueeze(1) - 1


@torch.inference_mode()
def inference(
    self,
    inputs: torch.FloatTensor,
    input_lens: torch.LongTensor,
    sampling_temp: float,
    max_sampling_time: int,
    duration_threshold: float,
):
    """Sample from the autoregressive neural HMM, one batch item at a time.

    Args:
        inputs (torch.FloatTensor): input states, :math:`(b, T, d)`.
        input_lens (torch.LongTensor): input state lengths, :math:`(b)`.
        sampling_temp (float): sampling temperature.
        max_sampling_time (int): hard cap on the number of sampled frames.
        duration_threshold (float): duration threshold to switch to the next state;
            use this to change the speaking rate of the synthesised audio.
    """
    outputs = {
        "hmm_outputs": [],
        "hmm_outputs_len": [],
        "alignments": [],
        "input_parameters": [],
        "output_parameters": [],
    }
    for idx in range(inputs.shape[0]):
        sampled_frames, states_travelled, input_parameters, output_parameters = self.sample(
            inputs[idx : idx + 1], input_lens[idx], sampling_temp, max_sampling_time, duration_threshold
        )
        outputs["hmm_outputs"].append(sampled_frames)
        outputs["hmm_outputs_len"].append(sampled_frames.shape[0])
        outputs["alignments"].append(states_travelled)
        outputs["input_parameters"].append(input_parameters)
        outputs["output_parameters"].append(output_parameters)

    # Pad the per-item samples into a single batch tensor.
    outputs["hmm_outputs"] = nn.utils.rnn.pad_sequence(outputs["hmm_outputs"], batch_first=True)
    outputs["hmm_outputs_len"] = torch.tensor(
        outputs["hmm_outputs_len"], dtype=input_lens.dtype, device=input_lens.device
    )
    return outputs
@torch.inference_mode()
def sample(self, inputs, input_lens, sampling_temp, max_sampling_time, duration_threshold):
    """Sample a single output sequence from the parameter models.

    Args:
        inputs (torch.FloatTensor): input states, :math:`(1, T, d)`.
        input_lens (torch.LongTensor): input state lengths, :math:`(1)`.
        sampling_temp (float): sampling temperature.
        max_sampling_time (int): max sampling time.
        duration_threshold (float): duration threshold to switch to the next state.

    Returns:
        outputs (torch.FloatTensor): output observations, :math:`(T, output_dim)`.
        states_travelled (one-hot tensor): hidden states travelled, :math:`(T)`.
        input_parameters (list[torch.FloatTensor]): input parameters per step.
        output_parameters (list[torch.FloatTensor]): output parameters per step.
    """
    states_travelled, sampled_frames, t = [], [], 0

    # Start in state 0 with go-token autoregressive context.
    current_state = 0
    states_travelled.append(current_state)
    prenet_input = self.go_tokens.unsqueeze(0).expand(1, self.ar_order, self.frame_channels)
    h_memory, c_memory = self._init_lstm_states(1, self.memory_rnn_dim, prenet_input)

    input_parameter_values = []
    output_parameter_values = []
    quantile = 1
    while True:
        memory_input = self.prenet(prenet_input.flatten(1).unsqueeze(0))
        # Batch dimension is 1 while sampling.
        h_memory, c_memory = self.memory_rnn(memory_input.squeeze(0), (h_memory, c_memory))

        z_t = inputs[:, current_state].unsqueeze(0)  # add fake time dimension
        mean, std, transition_vector = self.output_net(h_memory, z_t)

        transition_probability = torch.sigmoid(transition_vector.flatten())
        staying_probability = torch.sigmoid(-transition_vector.flatten())

        # Save for plotting.
        input_parameter_values.append([prenet_input, current_state])
        output_parameter_values.append([mean, std, transition_probability])

        x_t = self.emission_model.sample(mean, std, sampling_temp=sampling_temp)

        # Slide the autoregressive window forward by one frame.
        prenet_input = torch.cat((prenet_input, x_t), dim=1)[:, 1:]
        sampled_frames.append(x_t.flatten())

        switch_probabilities = torch.cat((staying_probability, transition_probability))
        quantile *= staying_probability
        if not self.deterministic_transition:
            switch = switch_probabilities.multinomial(1)[0].item()
        else:
            switch = quantile < duration_threshold

        if switch:
            current_state += 1
            quantile = 1
            states_travelled.append(current_state)

        if (current_state == input_lens) or (max_sampling_time and t == max_sampling_time - 1):
            break

        t += 1

    return (
        torch.stack(sampled_frames, dim=0),
        F.one_hot(input_lens.new_tensor(states_travelled)),
        input_parameter_values,
        output_parameter_values,
    )


@staticmethod
def _initialize_log_state_priors(text_embeddings):
    """Create log pi for the forward algorithm: probability mass 1 on state 0.

    Args:
        text_embeddings (torch.FloatTensor): template for device/dtype,
            shape (B, T, D_out_enc).

    Returns:
        (N,) tensor: 0.0 at index 0, -inf elsewhere.
    """
    N = text_embeddings.shape[1]
    log_state_priors = text_embeddings.new_full([N], -float("inf"))
    log_state_priors[0] = 0.0
    return log_state_priors
class TransitionModel(nn.Module):
    """Transition model of the HMM: the probability of moving from the current
    state to all other states."""

    def forward(self, log_alpha_scaled, transition_vector, inputs_len):  # pylint: disable=no-self-use
        r"""Combine previous alphas with transition probabilities in log space.

        Args:
            log_alpha_scaled (torch.Tensor): previous timestep's scaled alphas,
                shape (batch size, N).
            transition_vector (torch.tensor): transition vector per state, shape (N).
            inputs_len (int tensor): lengths of states in a batch, shape (batch).

        Returns:
            out (torch.FloatTensor): log probability of transitioning to each state.
        """
        move_p = torch.sigmoid(transition_vector)
        stay_p = torch.sigmoid(-transition_vector)

        log_stay = OverflowUtils.log_clamped(stay_p)
        log_move = OverflowUtils.log_clamped(move_p)

        # Two ways to be in state i at time t: stay in i, or move in from i-1.
        staying = log_alpha_scaled + log_stay
        leaving = (log_alpha_scaled + log_move).roll(1, dims=1)
        leaving[:, 0] = -float("inf")  # nothing can transition into the first state

        out = OverflowUtils.logsumexp(torch.stack((staying, leaving), dim=2), dim=2)
        # Padded states cannot contribute to the loss.
        out = out.masked_fill(~sequence_mask(inputs_len), -float("inf"))
        return out


class EmissionModel(nn.Module):
    """Emission model of the HMM: the probability of emitting an observation
    from the current state."""

    def __init__(self) -> None:
        super().__init__()
        self.distribution_function: tdist.Distribution = tdist.normal.Normal

    def sample(self, means, stds, sampling_temp):
        """Draw one sample per state; temperature 0 returns the means unchanged."""
        if sampling_temp > 0:
            return self.distribution_function(means, stds * sampling_temp).sample()
        return means

    def forward(self, x_t, means, stds, state_lengths):
        r"""Log probability of observing ``x_t`` under each state's Gaussian.

        Args:
            x_t (float tensor): observation at the current timestep, (batch, feature_dim).
            means (float tensor): state means, (batch, hidden_state, feature_dim).
            stds (float tensor): state standard deviations, (batch, hidden_state, feature_dim).
            state_lengths (int tensor): lengths of states in a batch, (batch).

        Returns:
            out (float tensor): per-state observation log likelihoods, (batch, hidden_state).
        """
        emission_dists = self.distribution_function(means, stds)
        per_dim_log_prob = emission_dists.log_prob(x_t.unsqueeze(1))
        # Sum over feature dims, zeroing contributions from padded states.
        state_lengths_mask = sequence_mask(state_lengths).unsqueeze(2)
        return torch.sum(per_dim_log_prob * state_lengths_mask, dim=2)
def validate_numpy_array(value: Any):
    r"""Validate the input and return it as a numpy array (i.e. on CPU).

    Args:
        value (Any): a numpy array, a torch tensor or a list.

    Raises:
        TypeError: if the value is not a numpy array, a torch tensor or a list.

    Returns:
        np.ndarray: numpy array of the value.
    """
    if isinstance(value, np.ndarray):
        return value
    if isinstance(value, list):
        return np.array(value)
    if torch.is_tensor(value):
        return value.cpu().numpy()
    raise TypeError("Value must be a numpy array, a torch tensor or a list")


def get_spec_from_most_probable_state(log_alpha_scaled, means, decoder):
    """Decode a mel spectrogram from the most probable state means.

    Args:
        log_alpha_scaled (torch.Tensor): log alpha scaled values, :math:`(T, N)`.
        means (torch.Tensor): means of the states, :math:`(N, T, D_out)`.
        decoder (torch.nn.Module): decoder module mapping latents to melspectrograms
            (invoked with ``reverse=True``).
    """
    # Pick, per timestep, the state with the highest (log) probability.
    max_state_numbers = torch.max(log_alpha_scaled, dim=1)[1]
    max_len = means.shape[0]
    n_mel_channels = means.shape[2]
    max_state_numbers = max_state_numbers.unsqueeze(1).unsqueeze(1).expand(max_len, 1, n_mel_channels)
    means = torch.gather(means, 1, max_state_numbers).squeeze(1).to(log_alpha_scaled.dtype)
    mel = (
        decoder(means.T.unsqueeze(0), torch.tensor([means.shape[0]], device=means.device), reverse=True)[0].squeeze(0).T
    )
    return mel


def plot_transition_probabilities_to_numpy(states, transition_probabilities, output_fig=False):
    """Generate a transition-probabilities plot for the states.

    Fixes over the original: docstring typo ("trainsition") and the unnecessary
    ``[i for i in range(...)]`` comprehension replaced with ``list(range(...))``.

    Args:
        states (torch.IntTensor): the states.
        transition_probabilities (torch.FloatTensor): the transition probabilities.
        output_fig (bool): keep the figure open when True. Defaults to False.
    """
    states = validate_numpy_array(states)
    transition_probabilities = validate_numpy_array(transition_probabilities)

    fig, ax = plt.subplots(figsize=(30, 3))
    ax.plot(transition_probabilities, "o")
    ax.set_title("Transition probability of state")
    ax.set_xlabel("hidden state")
    ax.set_ylabel("probability")
    ax.set_xticks(list(range(len(transition_probabilities))))
    ax.set_xticklabels([int(x) for x in states], rotation=90)
    plt.tight_layout()
    if not output_fig:
        plt.close()
    return fig
They combine the best features + of classic statistical speech synthesis and modern neural TTS, requiring less + data and fewer training updates, and are less prone to gibberish output caused + by neural attention failures. In this paper, we combine neural HMM TTS with + normalising flows for describing the highly non-Gaussian distribution of speech + acoustics. The result is a powerful, fully probabilistic model of durations and + acoustics that can be trained using exact maximum likelihood. Compared to + dominant flow-based acoustic models, our approach integrates autoregression for + improved modelling of long-range dependences such as utterance-level prosody. + Experiments show that a system based on our proposal gives more accurate + pronunciations and better subjective speech quality than comparable methods, + whilst retaining the original advantages of neural HMMs. Audio examples and code + are available at https://shivammehta25.github.io/OverFlow/. + + Note: + - Neural HMMs uses flat start initialization i.e it computes the means and std and transition probabilities + of the dataset and uses them to initialize the model. This benefits the model and helps with faster learning + If you change the dataset or want to regenerate the parameters change the `force_generate_statistics` and + `mel_statistics_parameter_path` accordingly. + + - To enable multi-GPU training, set the `use_grad_checkpointing=False` in config. + This will significantly increase the memory usage. This is because to compute + the actual data likelihood (not an approximation using MAS/Viterbi) we must use + all the states at the previous time step during the forward pass to decide the + probability distribution at the current step i.e the difference between the forward + algorithm and viterbi approximation. + + Check :class:`TTS.tts.configs.overflow.OverFlowConfig` for class arguments. 
+ """ + + def __init__( + self, + config: "OverFlowConfig", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + + # pass all config fields to `self` + # for fewer code change + self.config = config + for key in config: + setattr(self, key, config[key]) + + self.decoder_output_dim = config.out_channels + + self.encoder = Encoder(config.num_chars, config.state_per_phone, config.encoder_in_out_features) + self.neural_hmm = NeuralHMM( + frame_channels=self.out_channels, + ar_order=self.ar_order, + deterministic_transition=self.deterministic_transition, + encoder_dim=self.encoder_in_out_features, + prenet_type=self.prenet_type, + prenet_dim=self.prenet_dim, + prenet_n_layers=self.prenet_n_layers, + prenet_dropout=self.prenet_dropout, + prenet_dropout_at_inference=self.prenet_dropout_at_inference, + memory_rnn_dim=self.memory_rnn_dim, + outputnet_size=self.outputnet_size, + flat_start_params=self.flat_start_params, + std_floor=self.std_floor, + use_grad_checkpointing=self.use_grad_checkpointing, + ) + + self.decoder = Decoder( + self.out_channels, + self.hidden_channels_dec, + self.kernel_size_dec, + self.dilation_rate, + self.num_flow_blocks_dec, + self.num_block_layers, + dropout_p=self.dropout_p_dec, + num_splits=self.num_splits, + num_squeeze=self.num_squeeze, + sigmoid_scale=self.sigmoid_scale, + c_in_channels=self.c_in_channels, + ) + + self.register_buffer("mean", torch.tensor(0)) + self.register_buffer("std", torch.tensor(1)) + + # self.mean = nn.Parameter(torch.zeros(1), requires_grad=False) + # self.std = nn.Parameter(torch.ones(1), requires_grad=False) + + def update_mean_std(self, statistics_dict: Dict): + self.mean.data = torch.tensor(statistics_dict["mean"]) + self.std.data = torch.tensor(statistics_dict["std"]) + + def preprocess_batch(self, text, text_len, mels, mel_len): + if self.mean.item() == 0 or self.std.item() == 1: + 
def preprocess_batch(self, text, text_len, mels, mel_len):
    """Normalise the mels, lazily loading the dataset statistics on first use.

    mean==0 / std==1 are the construction-time sentinels, so either value at its
    default triggers a reload from ``mel_statistics_parameter_path``.
    """
    if self.mean.item() == 0 or self.std.item() == 1:
        statistics_dict = torch.load(self.mel_statistics_parameter_path)
        self.update_mean_std(statistics_dict)

    mels = self.normalize(mels)
    return text, text_len, mels, mel_len


def normalize(self, x):
    """Map ``x`` into the normalised mel space."""
    return (x - self.mean) / self.std


def inverse_normalize(self, x):
    """Map ``x`` back from the normalised mel space."""
    return x * self.std + self.mean


def forward(self, text, text_len, mels, mel_len):
    """Forward pass for training: log likelihood of a batch under the model.

    Shapes:
        text: :math:`[B, T_in]`
        text_len: :math:`[B]`
        mels: :math:`[B, T_out, C]`
        mel_len: :math:`[B]`
    """
    text, text_len, mels, mel_len = self.preprocess_batch(text, text_len, mels, mel_len)
    encoder_outputs, encoder_output_len = self.encoder(text, text_len)
    # The flow decoder maps mels to latents z; its logdet corrects the likelihood.
    z, z_lengths, logdet = self.decoder(mels.transpose(1, 2), mel_len)
    log_probs, fwd_alignments, transition_vectors, means = self.neural_hmm(
        encoder_outputs, encoder_output_len, z, z_lengths
    )

    return {
        "log_probs": log_probs + logdet,
        "alignments": fwd_alignments,
        "transition_vectors": transition_vectors,
        "means": means,
    }


def train_step(self, batch: dict, criterion: nn.Module):
    """Run one training step and compute the NLL loss."""
    outputs = self.forward(
        text=batch["text_input"],
        text_len=batch["text_lengths"],
        mels=batch["mel_input"],
        mel_len=batch["mel_lengths"],
    )
    loss_dict = criterion(outputs["log_probs"])
    return outputs, loss_dict


def eval_step(self, batch: Dict, criterion: nn.Module):
    """Evaluation reuses the training step (same likelihood computation)."""
    return self.train_step(batch, criterion)
+ """ + default_input_dict.update( + { + "sampling_temp": self.sampling_temp, + "max_sampling_time": self.max_sampling_time, + "duration_threshold": self.duration_threshold, + } + ) + if aux_input: + return format_aux_input(aux_input, default_input_dict) + return None + + @torch.no_grad() + def inference( + self, + text: torch.Tensor, + aux_input={"x_lengths": None, "sampling_temp": None, "max_sampling_time": None, "duration_threshold": None}, + ): # pylint: disable=dangerous-default-value + """Sampling from the model + + Args: + text (torch.Tensor): :math:`[B, T_in]` + aux_inputs (_type_, optional): _description_. Defaults to None. + + Returns: + outputs: Dictionary containing the following + - mel (torch.Tensor): :math:`[B, T_out, C]` + - hmm_outputs_len (torch.Tensor): :math:`[B]` + - state_travelled (List[List[int]]): List of lists containing the state travelled for each sample in the batch. + - input_parameters (list[torch.FloatTensor]): Input parameters to the neural HMM. + - output_parameters (list[torch.FloatTensor]): Output parameters to the neural HMM. 
+ """ + default_input_dict = { + "x_lengths": torch.sum(text != 0, dim=1), + } + aux_input = self._format_aux_input(aux_input, default_input_dict) + encoder_outputs, encoder_output_len = self.encoder.inference(text, aux_input["x_lengths"]) + outputs = self.neural_hmm.inference( + encoder_outputs, + encoder_output_len, + sampling_temp=aux_input["sampling_temp"], + max_sampling_time=aux_input["max_sampling_time"], + duration_threshold=aux_input["duration_threshold"], + ) + + mels, mel_outputs_len, _ = self.decoder( + outputs["hmm_outputs"].transpose(1, 2), outputs["hmm_outputs_len"], reverse=True + ) + mels = self.inverse_normalize(mels.transpose(1, 2)) + outputs.update({"model_outputs": mels, "model_outputs_len": mel_outputs_len}) + outputs["alignments"] = OverflowUtils.double_pad(outputs["alignments"]) + return outputs + + @staticmethod + def get_criterion(): + return NLLLoss() + + @staticmethod + def init_from_config(config: "OverFlowConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + """Initiate model from config + + Args: + config (VitsConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + verbose (bool): If True, print init messages. Defaults to True. 
+ """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config, verbose) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return Overflow(new_config, ap, tokenizer, speaker_manager) + + def load_checkpoint( + self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + self.decoder.store_inverse() + assert not self.training + + def on_init_start(self, trainer): + """If the current dataset does not have normalisation statistics and initialisation transition_probability it computes them otherwise loads.""" + if not os.path.isfile(trainer.config.mel_statistics_parameter_path) or trainer.config.force_generate_statistics: + dataloader = trainer.get_train_dataloader( + training_assets=None, samples=trainer.train_samples, verbose=False + ) + print( + f" | > Data parameters not found for: {trainer.config.mel_statistics_parameter_path}. Computing mel normalization parameters..." + ) + data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start( + dataloader, trainer.config.out_channels, trainer.config.state_per_phone + ) + print( + f" | > Saving data parameters to: {trainer.config.mel_statistics_parameter_path}: value: {data_mean, data_std, init_transition_prob}" + ) + statistics = { + "mean": data_mean.item(), + "std": data_std.item(), + "init_transition_prob": init_transition_prob.item(), + } + torch.save(statistics, trainer.config.mel_statistics_parameter_path) + + else: + print( + f" | > Data parameters found for: {trainer.config.mel_statistics_parameter_path}. Loading mel normalization parameters..." 
@torch.inference_mode()
def _create_logs(self, batch, outputs, ap):  # pylint: disable=no-self-use, unused-argument
    """Build figures and an audio sample for the logger from one batch.

    Returns:
        (figures, audios): dict of matplotlib figures and a dict with one
        synthesised waveform.
    """
    alignments, transition_vectors = outputs["alignments"], outputs["transition_vectors"]
    means = torch.stack(outputs["means"], dim=1)

    figures = {
        "alignment": plot_alignment(alignments[0].exp(), title="Forward alignment", fig_size=(20, 20)),
        "log_alignment": plot_alignment(
            alignments[0].exp(), title="Forward log alignment", plot_log=True, fig_size=(20, 20)
        ),
        "transition_vectors": plot_alignment(transition_vectors[0], title="Transition vectors", fig_size=(20, 20)),
        "mel_from_most_probable_state": plot_spectrogram(
            get_spec_from_most_probable_state(alignments[0], means[0], self.decoder), fig_size=(12, 3)
        ),
        "mel_target": plot_spectrogram(batch["mel_input"][0], fig_size=(12, 3)),
    }

    # Sample one item from the batch; index -1 gives the smallest item.
    print(" | > Synthesising audio from the model...")
    # BUGFIX: the aux_input key was misspelled "x_lenghts", so the supplied
    # lengths were silently ignored and recomputed from non-zero tokens.
    inference_output = self.inference(
        batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)}
    )
    figures["synthesised"] = plot_spectrogram(inference_output["model_outputs"][0], fig_size=(12, 3))

    states = [p[1] for p in inference_output["input_parameters"][0]]
    transition_probability_synthesising = [p[2].cpu().numpy() for p in inference_output["output_parameters"][0]]

    # Plot transition probabilities in chunks of 200 states per figure.
    for i in range((len(transition_probability_synthesising) // 200) + 1):
        start = i * 200
        end = (i + 1) * 200
        figures[f"synthesised_transition_probabilities/{i}"] = plot_transition_probabilities_to_numpy(
            states[start:end], transition_probability_synthesising[start:end]
        )

    audio = ap.inv_melspectrogram(inference_output["model_outputs"][0].T.cpu().numpy())
    return figures, {"audios": audio}


def train_log(
    self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
):  # pylint: disable=unused-argument
    """Log training progress."""
    figures, audios = self._create_logs(batch, outputs, self.ap)
    logger.train_figures(steps, figures)
    logger.train_audios(steps, audios, self.ap.sample_rate)


def eval_log(
    self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int
):  # pylint: disable=unused-argument
    """Compute and log evaluation metrics."""
    # Plot model parameter histograms (Tensorboard only; other loggers may not support it).
    if isinstance(logger, TensorboardLogger):
        for tag, value in self.named_parameters():
            tag = tag.replace(".", "/")
            logger.writer.add_histogram(tag, value.data.cpu().numpy(), steps)

    figures, audios = self._create_logs(batch, outputs, self.ap)
    logger.eval_figures(steps, figures)
    logger.eval_audios(steps, audios, self.ap.sample_rate)


def test_log(
    self, outputs: dict, logger: "Logger", assets: dict, steps: int  # pylint: disable=unused-argument
) -> None:
    """Log test-sentence audios and figures."""
    logger.test_audios(steps, outputs[1], self.ap.sample_rate)
    logger.test_figures(steps, outputs[0])
+ + Args: + logits (Tensor): [B, T, D] + + Returns: + Tensor: [1] + + """ + return_dict = {} + return_dict["loss"] = -log_prob.mean() + return return_dict diff --git a/TTS/tts/utils/visual.py b/TTS/tts/utils/visual.py index 78c12981..a823738d 100644 --- a/TTS/tts/utils/visual.py +++ b/TTS/tts/utils/visual.py @@ -3,18 +3,21 @@ import matplotlib import matplotlib.pyplot as plt import numpy as np import torch +from matplotlib.colors import LogNorm matplotlib.use("Agg") -def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None, output_fig=False): +def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None, output_fig=False, plot_log=False): if isinstance(alignment, torch.Tensor): alignment_ = alignment.detach().cpu().numpy().squeeze() else: alignment_ = alignment alignment_ = alignment_.astype(np.float32) if alignment_.dtype == np.float16 else alignment_ fig, ax = plt.subplots(figsize=fig_size) - im = ax.imshow(alignment_.T, aspect="auto", origin="lower", interpolation="none") + im = ax.imshow( + alignment_.T, aspect="auto", origin="lower", interpolation="none", norm=LogNorm() if plot_log else None + ) fig.colorbar(im, ax=ax) xlabel = "Decoder timestep" if info is not None: diff --git a/recipes/ljspeech/overflow/lj_parameters.pt b/recipes/ljspeech/overflow/lj_parameters.pt new file mode 100644 index 0000000000000000000000000000000000000000..a8d5dc34dd8b1aa4e8e55463cc40803762164a9d GIT binary patch literal 507 zcmWIWW@cev;NW1u0AdV03^`fx1&Kw8xv3?oMaB9li6x181=%@nP81aB#KK?3Ypv;7@D#UfP_SVL}p%QNqk9BVqS4(NoIat zd_hruQX#Xu{gwWY3P9;ny#Q}^jt-Z@AA^BB5DviUL{SDUG$)n;-3c_*%}EXJLYPaj z)5csN>n_eq4`l*c1;PQ|j35dgCdf_^14*C&>>-412D160DEb!zd8lSUy%6Bd#-;;R sD#xq~*93MOkO!mNp)x{@5C%{mJBT$C$_G&a-mGjOF(x1csfVZq0Q+`d7ytkO literal 0 HcmV?d00001 diff --git a/recipes/ljspeech/overflow/train_overflow.py b/recipes/ljspeech/overflow/train_overflow.py new file mode 100644 index 00000000..e05e399d --- /dev/null +++ b/recipes/ljspeech/overflow/train_overflow.py @@ -0,0 +1,96 @@ 
+import os + +from trainer import Trainer, TrainerArgs + +from TTS.config.shared_configs import BaseAudioConfig +from TTS.tts.configs.overflow_config import OverflowConfig +from TTS.tts.configs.shared_configs import BaseDatasetConfig +from TTS.tts.datasets import load_tts_samples +from TTS.tts.models.overflow import Overflow +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.utils.audio import AudioProcessor + +output_path = os.path.dirname(os.path.abspath(__file__)) + +# init configs +dataset_config = BaseDatasetConfig( + formatter="ljspeech", meta_file_train="metadata.csv", path=os.path.join("data", "LJSpeech-1.1/") +) + +audio_config = BaseAudioConfig( + sample_rate=22050, + do_trim_silence=True, + trim_db=60.0, + signal_norm=False, + mel_fmin=0.0, + mel_fmax=8000, + spec_gain=1.0, + log_func="np.log", + ref_level_db=20, + preemphasis=0.0, +) + +config = OverflowConfig( # This is the config that is saved for the future use + run_name="overflow_ljspeech", + audio=audio_config, + batch_size=30, + eval_batch_size=16, + num_loader_workers=4, + num_eval_loader_workers=4, + run_eval=True, + test_delay_epochs=-1, + epochs=1000, + text_cleaner="phoneme_cleaners", + use_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), + precompute_num_workers=8, + mel_statistics_parameter_path=os.path.join(output_path, "lj_parameters.pt"), + force_generate_statistics=False, + print_step=1, + print_eval=True, + mixed_precision=True, + output_path=output_path, + datasets=[dataset_config], +) + +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) + +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. 
+# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. +train_samples, eval_samples = load_tts_samples( + dataset_config, + eval_split=True, + eval_split_max_size=config.eval_split_max_size, + eval_split_size=config.eval_split_size, +) + +# INITIALIZE THE MODEL +# Models take a config object and a speaker manager as input +# Config defines the details of the model like the number of layers, the size of the embedding, etc. +# Speaker manager is used by multi-speaker models. +model = Overflow(config, ap, tokenizer) + + +# init the trainer and 🚀 +trainer = Trainer( + TrainerArgs(), + config, + output_path, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + gpu=1, +) +trainer.fit() diff --git a/tests/tts_tests/test_overflow.py b/tests/tts_tests/test_overflow.py new file mode 100644 index 00000000..01c44719 --- /dev/null +++ b/tests/tts_tests/test_overflow.py @@ -0,0 +1,399 @@ +import os +import random +import unittest +from copy import deepcopy + +import torch + +from tests import get_tests_output_path +from TTS.tts.configs.overflow_config import OverflowConfig +from TTS.tts.layers.overflow.common_layers import Encoder, Outputnet, OverflowUtils +from TTS.tts.layers.overflow.decoder import Decoder +from TTS.tts.layers.overflow.neural_hmm import EmissionModel, NeuralHMM, TransitionModel +from TTS.tts.models.overflow import Overflow +from TTS.tts.utils.helpers import sequence_mask +from TTS.utils.audio import AudioProcessor + +# pylint: disable=unused-variable + +torch.manual_seed(1) +use_cuda = torch.cuda.is_available() +device = 
torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+config_global = OverflowConfig(num_chars=24)
+ap = AudioProcessor.init_from_config(config_global)
+
+config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
+output_path = os.path.join(get_tests_output_path(), "train_outputs")
+parameter_path = os.path.join(get_tests_output_path(), "lj_parameters.pt")
+
+torch.save({"mean": -5.5138, "std": 2.0636, "init_transition_prob": 0.3212}, parameter_path)
+
+
+def _create_inputs(batch_size=8):
+    max_len_t, max_len_m = random.randint(25, 50), random.randint(50, 80)
+    input_dummy = torch.randint(0, 24, (batch_size, max_len_t)).long().to(device)
+    input_lengths = torch.randint(20, max_len_t, (batch_size,)).long().to(device).sort(descending=True)[0]
+    input_lengths[0] = max_len_t
+    input_dummy = input_dummy * sequence_mask(input_lengths)
+    mel_spec = torch.randn(batch_size, max_len_m, config_global.audio["num_mels"]).to(device)
+    mel_lengths = torch.randint(40, max_len_m, (batch_size,)).long().to(device).sort(descending=True)[0]
+    mel_lengths[0] = max_len_m
+    mel_spec = mel_spec * sequence_mask(mel_lengths).unsqueeze(2)
+    return input_dummy, input_lengths, mel_spec, mel_lengths
+
+
+def get_model(config=None):
+    if config is None:
+        config = config_global
+    config.mel_statistics_parameter_path = parameter_path
+    model = Overflow(config)
+    model = model.to(device)
+    return model
+
+
+def reset_all_weights(model):
+    """
+    refs:
+        - https://discuss.pytorch.org/t/how-to-re-set-alll-parameters-in-a-network/20819/6
+        - https://stackoverflow.com/questions/63627997/reset-parameters-of-a-neural-network-in-pytorch
+        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
+    """
+
+    @torch.no_grad()
+    def weight_reset(m):
+        # - check if the current module has reset_parameters & if it's callable, call it on m
+        reset_parameters = getattr(m, "reset_parameters", None)
+        if callable(reset_parameters):
+            m.reset_parameters()
+
+    # Applies fn
recursively to every submodule see: https://pytorch.org/docs/stable/generated/torch.nn.Module.html + model.apply(fn=weight_reset) + + +class TestOverflow(unittest.TestCase): + def test_forward(self): + model = get_model() + input_dummy, input_lengths, mel_spec, mel_lengths = _create_inputs() + outputs = model(input_dummy, input_lengths, mel_spec, mel_lengths) + self.assertEqual(outputs["log_probs"].shape, (input_dummy.shape[0],)) + self.assertEqual(model.state_per_phone * max(input_lengths), outputs["alignments"].shape[2]) + + def test_inference(self): + model = get_model() + input_dummy, input_lengths, mel_spec, mel_lengths = _create_inputs() + output_dict = model.inference(input_dummy) + self.assertEqual(output_dict["model_outputs"].shape[2], config_global.out_channels) + + def test_init_from_config(self): + config = deepcopy(config_global) + config.mel_statistics_parameter_path = parameter_path + config.prenet_dim = 256 + model = Overflow.init_from_config(config_global) + self.assertEqual(model.prenet_dim, config.prenet_dim) + + +class TestOverflowEncoder(unittest.TestCase): + @staticmethod + def get_encoder(state_per_phone): + config = deepcopy(config_global) + config.state_per_phone = state_per_phone + config.num_chars = 24 + return Encoder(config.num_chars, config.state_per_phone, config.prenet_dim, config.encoder_n_convolutions).to( + device + ) + + def test_forward_with_state_per_phone_multiplication(self): + for s_p_p in [1, 2, 3]: + input_dummy, input_lengths, _, _ = _create_inputs() + model = self.get_encoder(s_p_p) + x, x_len = model(input_dummy, input_lengths) + self.assertEqual(x.shape[1], input_dummy.shape[1] * s_p_p) + + def test_inference_with_state_per_phone_multiplication(self): + for s_p_p in [1, 2, 3]: + input_dummy, input_lengths, _, _ = _create_inputs() + model = self.get_encoder(s_p_p) + x, x_len = model.inference(input_dummy, input_lengths) + self.assertEqual(x.shape[1], input_dummy.shape[1] * s_p_p) + + +class 
TestOverflowUtils(unittest.TestCase): + def test_logsumexp(self): + a = torch.randn(10) # random numbers + self.assertTrue(torch.eq(torch.logsumexp(a, dim=0), OverflowUtils.logsumexp(a, dim=0)).all()) + + a = torch.zeros(10) # all zeros + self.assertTrue(torch.eq(torch.logsumexp(a, dim=0), OverflowUtils.logsumexp(a, dim=0)).all()) + + a = torch.ones(10) # all ones + self.assertTrue(torch.eq(torch.logsumexp(a, dim=0), OverflowUtils.logsumexp(a, dim=0)).all()) + + +class TestOverflowDecoder(unittest.TestCase): + @staticmethod + def _get_decoder(num_flow_blocks_dec=None, hidden_channels_dec=None, reset_weights=True): + config = deepcopy(config_global) + config.num_flow_blocks_dec = ( + num_flow_blocks_dec if num_flow_blocks_dec is not None else config.num_flow_blocks_dec + ) + config.hidden_channels_dec = ( + hidden_channels_dec if hidden_channels_dec is not None else config.hidden_channels_dec + ) + config.dropout_p_dec = 0.0 # turn off dropout to check invertibility + decoder = Decoder( + config.out_channels, + config.hidden_channels_dec, + config.kernel_size_dec, + config.dilation_rate, + config.num_flow_blocks_dec, + config.num_block_layers, + config.dropout_p_dec, + config.num_splits, + config.num_squeeze, + config.sigmoid_scale, + config.c_in_channels, + ).to(device) + if reset_weights: + reset_all_weights(decoder) + return decoder + + def test_decoder_forward_backward(self): + for num_flow_blocks_dec in [8, None]: + for hidden_channels_dec in [100, None]: + decoder = self._get_decoder(num_flow_blocks_dec, hidden_channels_dec) + _, _, mel_spec, mel_lengths = _create_inputs() + z, z_len, _ = decoder(mel_spec.transpose(1, 2), mel_lengths) + mel_spec_, mel_lengths_, _ = decoder(z, z_len, reverse=True) + mask = sequence_mask(z_len).unsqueeze(1) + mel_spec = mel_spec[:, : z.shape[2], :].transpose(1, 2) * mask + z = z * mask + self.assertTrue( + torch.isclose(mel_spec, mel_spec_, atol=1e-2).all(), + f"num_flow_blocks_dec={num_flow_blocks_dec}, 
hidden_channels_dec={hidden_channels_dec}", + ) + + +class TestNeuralHMM(unittest.TestCase): + @staticmethod + def _get_neural_hmm(deterministic_transition=None): + config = deepcopy(config_global) + neural_hmm = NeuralHMM( + config.out_channels, + config.ar_order, + config.deterministic_transition if deterministic_transition is None else deterministic_transition, + config.encoder_in_out_features, + config.prenet_type, + config.prenet_dim, + config.prenet_n_layers, + config.prenet_dropout, + config.prenet_dropout_at_inference, + config.memory_rnn_dim, + config.outputnet_size, + config.flat_start_params, + config.std_floor, + ).to(device) + return neural_hmm + + @staticmethod + def _get_emission_model(): + return EmissionModel().to(device) + + @staticmethod + def _get_transition_model(): + return TransitionModel().to(device) + + @staticmethod + def _get_embedded_input(): + input_dummy, input_lengths, mel_spec, mel_lengths = _create_inputs() + input_dummy = torch.nn.Embedding(config_global.num_chars, config_global.encoder_in_out_features).to(device)( + input_dummy + ) + return input_dummy, input_lengths, mel_spec, mel_lengths + + def test_neural_hmm_forward(self): + input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input() + neural_hmm = self._get_neural_hmm() + log_prob, log_alpha_scaled, transition_matrix, means = neural_hmm( + input_dummy, input_lengths, mel_spec.transpose(1, 2), mel_lengths + ) + self.assertEqual(log_prob.shape, (input_dummy.shape[0],)) + self.assertEqual(log_alpha_scaled.shape, transition_matrix.shape) + + def test_mask_lengths(self): + input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input() + neural_hmm = self._get_neural_hmm() + log_prob, log_alpha_scaled, transition_matrix, means = neural_hmm( + input_dummy, input_lengths, mel_spec.transpose(1, 2), mel_lengths + ) + log_c = torch.randn(mel_spec.shape[0], mel_spec.shape[1], device=device) + log_c, log_alpha_scaled = neural_hmm._mask_lengths( # 
pylint: disable=protected-access + mel_lengths, log_c, log_alpha_scaled + ) + assertions = [] + for i in range(mel_spec.shape[0]): + assertions.append(log_c[i, mel_lengths[i] :].sum() == 0.0) + self.assertTrue(all(assertions), "Incorrect masking") + assertions = [] + for i in range(mel_spec.shape[0]): + assertions.append(log_alpha_scaled[i, mel_lengths[i] :, : input_lengths[i]].sum() == 0.0) + self.assertTrue(all(assertions), "Incorrect masking") + + def test_process_ar_timestep(self): + model = self._get_neural_hmm() + input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input() + + h_post_prenet, c_post_prenet = model._init_lstm_states( # pylint: disable=protected-access + input_dummy.shape[0], config_global.memory_rnn_dim, mel_spec + ) + h_post_prenet, c_post_prenet = model._process_ar_timestep( # pylint: disable=protected-access + 1, + mel_spec, + h_post_prenet, + c_post_prenet, + ) + + self.assertEqual(h_post_prenet.shape, (input_dummy.shape[0], config_global.memory_rnn_dim)) + self.assertEqual(c_post_prenet.shape, (input_dummy.shape[0], config_global.memory_rnn_dim)) + + def test_add_go_token(self): + model = self._get_neural_hmm() + input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input() + + out = model._add_go_token(mel_spec) # pylint: disable=protected-access + self.assertEqual(out.shape, mel_spec.shape) + self.assertTrue((out[:, 1:] == mel_spec[:, :-1]).all(), "Go token not appended properly") + + def test_forward_algorithm_variables(self): + model = self._get_neural_hmm() + input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input() + + ( + log_c, + log_alpha_scaled, + transition_matrix, + _, + ) = model._initialize_forward_algorithm_variables( # pylint: disable=protected-access + mel_spec, input_dummy.shape[1] * config_global.state_per_phone + ) + + self.assertEqual(log_c.shape, (mel_spec.shape[0], mel_spec.shape[1])) + self.assertEqual( + log_alpha_scaled.shape, + ( + mel_spec.shape[0], + 
mel_spec.shape[1], + input_dummy.shape[1] * config_global.state_per_phone, + ), + ) + self.assertEqual( + transition_matrix.shape, + (mel_spec.shape[0], mel_spec.shape[1], input_dummy.shape[1] * config_global.state_per_phone), + ) + + def test_get_absorption_state_scaling_factor(self): + model = self._get_neural_hmm() + input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input() + input_lengths = input_lengths * config_global.state_per_phone + ( + log_c, + log_alpha_scaled, + transition_matrix, + _, + ) = model._initialize_forward_algorithm_variables( # pylint: disable=protected-access + mel_spec, input_dummy.shape[1] * config_global.state_per_phone + ) + log_alpha_scaled = torch.rand_like(log_alpha_scaled).clamp(1e-3) + transition_matrix = torch.randn_like(transition_matrix).sigmoid().log() + sum_final_log_c = model.get_absorption_state_scaling_factor( + mel_lengths, log_alpha_scaled, input_lengths, transition_matrix + ) + + text_mask = ~sequence_mask(input_lengths) + transition_prob_mask = ~model.get_mask_for_last_item(input_lengths, device=input_lengths.device) + + outputs = [] + + for i in range(input_dummy.shape[0]): + last_log_alpha_scaled = log_alpha_scaled[i, mel_lengths[i] - 1].masked_fill(text_mask[i], -float("inf")) + log_last_transition_probability = OverflowUtils.log_clamped( + torch.sigmoid(transition_matrix[i, mel_lengths[i] - 1]) + ).masked_fill(transition_prob_mask[i], -float("inf")) + outputs.append(last_log_alpha_scaled + log_last_transition_probability) + + sum_final_log_c_computed = torch.logsumexp(torch.stack(outputs), dim=1) + + self.assertTrue(torch.isclose(sum_final_log_c_computed, sum_final_log_c).all()) + + def test_inference(self): + model = self._get_neural_hmm() + input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input() + for temp in [0.334, 0.667, 1.0]: + outputs = model.inference( + input_dummy, input_lengths, temp, config_global.max_sampling_time, config_global.duration_threshold + ) + 
self.assertEqual(outputs["hmm_outputs"].shape[-1], outputs["input_parameters"][0][0][0].shape[-1]) + self.assertEqual( + outputs["output_parameters"][0][0][0].shape[-1], outputs["input_parameters"][0][0][0].shape[-1] + ) + self.assertEqual(len(outputs["alignments"]), input_dummy.shape[0]) + + def test_emission_model(self): + model = self._get_emission_model() + input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input() + x_t = torch.randn(input_dummy.shape[0], config_global.out_channels).to(device) + means = torch.randn(input_dummy.shape[0], input_dummy.shape[1], config_global.out_channels).to(device) + std = torch.rand_like(means).to(device).clamp_(1e-3) # std should be positive + out = model(x_t, means, std, input_lengths) + self.assertEqual(out.shape, (input_dummy.shape[0], input_dummy.shape[1])) + + # testing sampling + for temp in [0, 0.334, 0.667]: + out = model.sample(means, std, 0) + self.assertEqual(out.shape, means.shape) + if temp == 0: + self.assertTrue(torch.isclose(out, means).all()) + + def test_transition_model(self): + model = self._get_transition_model() + input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input() + prev_t_log_scaled_alph = torch.randn(input_dummy.shape[0], input_lengths.max()).to(device) + transition_vector = torch.randn(input_lengths.max()).to(device) + out = model(prev_t_log_scaled_alph, transition_vector, input_lengths) + self.assertEqual(out.shape, (input_dummy.shape[0], input_lengths.max())) + + +class TestOverflowOutputNet(unittest.TestCase): + @staticmethod + def _get_outputnet(): + config = deepcopy(config_global) + outputnet = Outputnet( + config.encoder_in_out_features, + config.memory_rnn_dim, + config.out_channels, + config.outputnet_size, + config.flat_start_params, + config.std_floor, + ).to(device) + return outputnet + + @staticmethod + def _get_embedded_input(): + input_dummy, input_lengths, mel_spec, mel_lengths = _create_inputs() + input_dummy = 
torch.nn.Embedding(config_global.num_chars, config_global.encoder_in_out_features).to(device)( + input_dummy + ) + one_timestep_frame = torch.randn(input_dummy.shape[0], config_global.memory_rnn_dim).to(device) + return input_dummy, one_timestep_frame + + def test_outputnet_forward_with_flat_start(self): + model = self._get_outputnet() + input_dummy, one_timestep_frame = self._get_embedded_input() + mean, std, transition_vector = model(one_timestep_frame, input_dummy) + self.assertTrue(torch.isclose(mean, torch.tensor(model.flat_start_params["mean"] * 1.0)).all()) + self.assertTrue(torch.isclose(std, torch.tensor(model.flat_start_params["std"] * 1.0)).all()) + self.assertTrue( + torch.isclose( + transition_vector.sigmoid(), torch.tensor(model.flat_start_params["transition_p"] * 1.0) + ).all() + ) diff --git a/tests/tts_tests/test_overflow_train.py b/tests/tts_tests/test_overflow_train.py new file mode 100644 index 00000000..86fa60af --- /dev/null +++ b/tests/tts_tests/test_overflow_train.py @@ -0,0 +1,92 @@ +import glob +import json +import os +import shutil + +import torch +from trainer import get_last_checkpoint + +from tests import get_device_id, get_tests_output_path, run_cli +from TTS.tts.configs.overflow_config import OverflowConfig + +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") +output_path = os.path.join(get_tests_output_path(), "train_outputs") +parameter_path = os.path.join(get_tests_output_path(), "lj_parameters.pt") + +torch.save({"mean": -5.5138, "std": 2.0636, "init_transition_prob": 0.3212}, parameter_path) + +config = OverflowConfig( + batch_size=3, + eval_batch_size=3, + num_loader_workers=0, + num_eval_loader_workers=0, + text_cleaner="phoneme_cleaners", + use_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path=os.path.join(get_tests_output_path(), "train_outputs/phoneme_cache/"), + run_eval=True, + test_delay_epochs=-1, + mel_statistics_parameter_path=parameter_path, + epochs=1, + print_step=1, + 
test_sentences=[ + "Be a voice, not an echo.", + ], + print_eval=True, + max_sampling_time=50, +) +config.audio.do_trim_silence = True +config.audio.trim_db = 60 +config.save_json(config_path) + + +# train the model for one epoch when mel parameters exists +command_train = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.datasets.0.formatter ljspeech " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + "--coqpit.test_delay_epochs 0 " +) +run_cli(command_train) + + +# train the model for one epoch when mel parameters have to be computed from the dataset +if os.path.exists(parameter_path): + os.remove(parameter_path) +command_train = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.datasets.0.formatter ljspeech " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + "--coqpit.test_delay_epochs 0 " +) +run_cli(command_train) + +# Find latest folder +continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) + +# Inference using TTS API +continue_config_path = os.path.join(continue_path, "config.json") +continue_restore_path, _ = get_last_checkpoint(continue_path) +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") + +# Check integrity of the config +with open(continue_config_path, "r", encoding="utf-8") as f: + config_loaded = json.load(f) +assert config_loaded["characters"] is not None +assert config_loaded["output_path"] in continue_path +assert config_loaded["test_delay_epochs"] == 0 + +# Load the model and run inference +inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This 
is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" +run_cli(inference_command) + +# restore the model and continue training for one more epoch +command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " +run_cli(command_train) +shutil.rmtree(continue_path)