diff --git a/README.md b/README.md index 3c960af6..5a28bc8d 100644 --- a/README.md +++ b/README.md @@ -95,6 +95,7 @@ Underlined "TTS*" and "Judy*" are 🐸TTS models - SC-GlowTTS: [paper](https://arxiv.org/abs/2104.05557) - Capacitron: [paper](https://arxiv.org/abs/1906.03402) - OverFlow: [paper](https://arxiv.org/abs/2211.06892) +- Neural HMM TTS: [paper](https://arxiv.org/abs/2108.13320) ### End-to-End Models - VITS: [paper](https://arxiv.org/pdf/2106.06103) diff --git a/TTS/tts/configs/neuralhmm_tts_config.py b/TTS/tts/configs/neuralhmm_tts_config.py new file mode 100644 index 00000000..50f72847 --- /dev/null +++ b/TTS/tts/configs/neuralhmm_tts_config.py @@ -0,0 +1,170 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig + + +@dataclass +class NeuralhmmTTSConfig(BaseTTSConfig): + """ + Define parameters for Neural HMM TTS model. + + Example: + + >>> from TTS.tts.configs.overflow_config import OverflowConfig + >>> config = OverflowConfig() + + Args: + model (str): + Model name used to select the right model class to initilize. Defaults to `Overflow`. + run_eval_steps (int): + Run evalulation epoch after N steps. If None, waits until training epoch is completed. Defaults to None. + save_step (int): + Save local checkpoint every save_step steps. Defaults to 500. + plot_step (int): + Plot training stats on the logger every plot_step steps. Defaults to 1. + model_param_stats (bool): + Log model parameters stats on the logger dashboard. Defaults to False. + force_generate_statistics (bool): + Force generate mel normalization statistics. Defaults to False. + mel_statistics_parameter_path (str): + Path to the mel normalization statistics.If the model doesn't finds a file there it will generate statistics. + Defaults to None. + num_chars (int): + Number of characters used by the model. It must be defined before initializing the model. Defaults to None. + state_per_phone (int): + Generates N states per phone. Similar, to `add_blank` parameter in GlowTTS but in Overflow it is upsampled by model's encoder. Defaults to 2. + encoder_in_out_features (int): + Channels of encoder input and character embedding tensors. Defaults to 512. + encoder_n_convolutions (int): + Number of convolution layers in the encoder. Defaults to 3. + out_channels (int): + Channels of the final model output. It must match the spectragram size. Defaults to 80. + ar_order (int): + Autoregressive order of the model. Defaults to 1. In ablations of Neural HMM it was found that more autoregression while giving more variation hurts naturalness of the synthesised audio. + sampling_temp (float): + Variation added to the sample from the latent space of neural HMM. Defaults to 0.334. + deterministic_transition (bool): + deterministic duration generation based on duration quantiles as defiend in "S. Ronanki, O. Watts, S. King, and G. E. Henter, “Medianbased generation of synthetic speech durations using a nonparametric approach,” in Proc. SLT, 2016.". Defaults to True. + duration_threshold (float): + Threshold for duration quantiles. Defaults to 0.55. Tune this to change the speaking rate of the synthesis, where lower values defines a slower speaking rate and higher values defines a faster speaking rate. + use_grad_checkpointing (bool): + Use gradient checkpointing to save memory. In a multi-GPU setting currently pytorch does not supports gradient checkpoint inside a loop so we will have to turn it off then.Adjust depending on whatever get more batch size either by using a single GPU or multi-GPU. Defaults to True. + max_sampling_time (int): + Maximum sampling time while synthesising latents from neural HMM. Defaults to 1000. + prenet_type (str): + `original` or `bn`. `original` sets the default Prenet and `bn` uses Batch Normalization version of the + Prenet. Defaults to `original`. + prenet_dim (int): + Dimension of the Prenet. Defaults to 256. + prenet_n_layers (int): + Number of layers in the Prenet. Defaults to 2. + prenet_dropout (float): + Dropout rate of the Prenet. Defaults to 0.5. + prenet_dropout_at_inference (bool): + Use dropout at inference time. Defaults to False. + memory_rnn_dim (int): + Dimension of the memory LSTM to process the prenet output. Defaults to 1024. + outputnet_size (list[int]): + Size of the output network inside the neural HMM. Defaults to [1024]. + flat_start_params (dict): + Parameters for the flat start initialization of the neural HMM. Defaults to `{"mean": 0.0, "std": 1.0, "transition_p": 0.14}`. + It will be recomputed when you pass the dataset. + std_floor (float): + Floor value for the standard deviation of the neural HMM. Prevents model cheating by putting point mass and getting infinite likelihood at any datapoint. Defaults to 0.01. + It is called `variance flooring` in standard HMM literature. + optimizer (str): + Optimizer to use for training. Defaults to `adam`. + optimizer_params (dict): + Parameters for the optimizer. Defaults to `{"weight_decay": 1e-6}`. + grad_clip (float): + Gradient clipping threshold. Defaults to 40_000. + lr (float): + Learning rate. Defaults to 1e-3. + lr_scheduler (str): + Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or + `TTS.utils.training`. Defaults to `None`. + min_seq_len (int): + Minimum input sequence length to be used at training. + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + """ + + model: str = "NeuralHMM_TTS" + + # Training and Checkpoint configs + run_eval_steps: int = 100 + save_step: int = 500 + plot_step: int = 1 + model_param_stats: bool = False + + # data parameters + force_generate_statistics: bool = False + mel_statistics_parameter_path: str = None + + # Encoder parameters + num_chars: int = None + state_per_phone: int = 2 + encoder_in_out_features: int = 512 + encoder_n_convolutions: int = 3 + + # HMM parameters + out_channels: int = 80 + ar_order: int = 1 + sampling_temp: float = 0 + deterministic_transition: bool = True + duration_threshold: float = 0.43 + use_grad_checkpointing: bool = True + max_sampling_time: int = 1000 + + ## Prenet parameters + prenet_type: str = "original" + prenet_dim: int = 256 + prenet_n_layers: int = 2 + prenet_dropout: float = 0.5 + prenet_dropout_at_inference: bool = True + memory_rnn_dim: int = 1024 + + ## Outputnet parameters + outputnet_size: List[int] = field(default_factory=lambda: [1024]) + flat_start_params: dict = field(default_factory=lambda: {"mean": 0.0, "std": 1.0, "transition_p": 0.14}) + std_floor: float = 0.001 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"weight_decay": 1e-6}) + grad_clip: float = 40000.0 + lr: float = 1e-3 + lr_scheduler: str = None + + # overrides + min_text_len: int = 10 + max_text_len: int = 500 + min_audio_len: int = 512 + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "Be a voice, not an echo.", + ] + ) + + # Extra needed config + r: int = 1 + use_d_vector_file: bool = False + use_speaker_embedding: bool = False + + def check_values(self): + """Validate the hyperparameters. + + Raises: + AssertionError: when the parameters network is not defined + AssertionError: transition probability is not between 0 and 1 + """ + assert self.ar_order > 0, "AR order must be greater than 0 it is an autoregressive model." + assert ( + len(self.outputnet_size) >= 1 + ), f"Parameter Network must have atleast one layer check the config file for parameter network. Provided: {self.parameternetwork}" + assert ( + 0 < self.flat_start_params["transition_p"] < 1 + ), f"Transition probability must be between 0 and 1. Provided: {self.flat_start_params['transition_p']}" diff --git a/TTS/tts/layers/overflow/plotting_utils.py b/TTS/tts/layers/overflow/plotting_utils.py index a9f9c301..a63aeb37 100644 --- a/TTS/tts/layers/overflow/plotting_utils.py +++ b/TTS/tts/layers/overflow/plotting_utils.py @@ -30,7 +30,7 @@ def validate_numpy_array(value: Any): return value -def get_spec_from_most_probable_state(log_alpha_scaled, means, decoder): +def get_spec_from_most_probable_state(log_alpha_scaled, means, decoder=None): """Get the most probable state means from the log_alpha_scaled. Args: @@ -38,16 +38,21 @@ def get_spec_from_most_probable_state(log_alpha_scaled, means, decoder): - Shape: :math:`(T, N)` means (torch.Tensor): Means of the states. - Shape: :math:`(N, T, D_out)` - decoder (torch.nn.Module): Decoder module to decode the latent to melspectrogram + decoder (torch.nn.Module): Decoder module to decode the latent to melspectrogram. Defaults to None. """ max_state_numbers = torch.max(log_alpha_scaled, dim=1)[1] max_len = means.shape[0] n_mel_channels = means.shape[2] max_state_numbers = max_state_numbers.unsqueeze(1).unsqueeze(1).expand(max_len, 1, n_mel_channels) means = torch.gather(means, 1, max_state_numbers).squeeze(1).to(log_alpha_scaled.dtype) - mel = ( - decoder(means.T.unsqueeze(0), torch.tensor([means.shape[0]], device=means.device), reverse=True)[0].squeeze(0).T - ) + if decoder is not None: + mel = ( + decoder(means.T.unsqueeze(0), torch.tensor([means.shape[0]], device=means.device), reverse=True)[0] + .squeeze(0) + .T + ) + else: + mel = means return mel diff --git a/TTS/tts/models/neuralhmm_tts.py b/TTS/tts/models/neuralhmm_tts.py new file mode 100644 index 00000000..e4f88452 --- /dev/null +++ b/TTS/tts/models/neuralhmm_tts.py @@ -0,0 +1,384 @@ +import os +from typing import Dict, List, Union + +import torch +from coqpit import Coqpit +from torch import nn +from trainer.logging.tensorboard_logger import TensorboardLogger + +from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils +from TTS.tts.layers.overflow.neural_hmm import NeuralHMM +from TTS.tts.layers.overflow.plotting_utils import ( + get_spec_from_most_probable_state, + plot_transition_probabilities_to_numpy, +) +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.generic_utils import format_aux_input +from TTS.utils.io import load_fsspec + + +class NeuralhmmTTS(BaseTTS): + """Neural HMM TTS model. + + Paper:: + https://arxiv.org/abs/2108.13320 + + Paper abstract:: + Neural sequence-to-sequence TTS has achieved significantly better output quality + than statistical speech synthesis using HMMs.However, neural TTS is generally not probabilistic + and uses non-monotonic attention. Attention failures increase training time and can make + synthesis babble incoherently. This paper describes how the old and new paradigms can be + combined to obtain the advantages of both worlds, by replacing attention in neural TTS with + an autoregressive left-right no-skip hidden Markov model defined by a neural network. + Based on this proposal, we modify Tacotron 2 to obtain an HMM-based neural TTS model with + monotonic alignment, trained to maximise the full sequence likelihood without approximation. + We also describe how to combine ideas from classical and contemporary TTS for best results. + The resulting example system is smaller and simpler than Tacotron 2, and learns to speak with + fewer iterations and less data, whilst achieving comparable naturalness prior to the post-net. + Our approach also allows easy control over speaking rate. Audio examples and code + are available at https://shivammehta25.github.io/Neural-HMM/ . + + Note: + - This is a parameter efficient version of OverFlow (15.3M vs 28.6M). Since it has half the + number of parameters as OverFlow the synthesis output quality is suboptimal (but comparable to Tacotron2 + without Postnet), but it learns to speak with even lesser amount of data and is still significantly faster + than other attention-based methods. + + - Neural HMMs uses flat start initialization i.e it computes the means and std and transition probabilities + of the dataset and uses them to initialize the model. This benefits the model and helps with faster learning + If you change the dataset or want to regenerate the parameters change the `force_generate_statistics` and + `mel_statistics_parameter_path` accordingly. + + - To enable multi-GPU training, set the `use_grad_checkpointing=False` in config. + This will significantly increase the memory usage. This is because to compute + the actual data likelihood (not an approximation using MAS/Viterbi) we must use + all the states at the previous time step during the forward pass to decide the + probability distribution at the current step i.e the difference between the forward + algorithm and viterbi approximation. + + Check :class:`TTS.tts.configs.neuralhmm_tts_config.NeuralhmmTTSConfig` for class arguments. + """ + + def __init__( + self, + config: "NeuralhmmTTSConfig", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + + # pass all config fields to `self` + # for fewer code change + self.config = config + for key in config: + setattr(self, key, config[key]) + + self.encoder = Encoder(config.num_chars, config.state_per_phone, config.encoder_in_out_features) + self.neural_hmm = NeuralHMM( + frame_channels=self.out_channels, + ar_order=self.ar_order, + deterministic_transition=self.deterministic_transition, + encoder_dim=self.encoder_in_out_features, + prenet_type=self.prenet_type, + prenet_dim=self.prenet_dim, + prenet_n_layers=self.prenet_n_layers, + prenet_dropout=self.prenet_dropout, + prenet_dropout_at_inference=self.prenet_dropout_at_inference, + memory_rnn_dim=self.memory_rnn_dim, + outputnet_size=self.outputnet_size, + flat_start_params=self.flat_start_params, + std_floor=self.std_floor, + use_grad_checkpointing=self.use_grad_checkpointing, + ) + + self.register_buffer("mean", torch.tensor(0)) + self.register_buffer("std", torch.tensor(1)) + + def update_mean_std(self, statistics_dict: Dict): + self.mean.data = torch.tensor(statistics_dict["mean"]) + self.std.data = torch.tensor(statistics_dict["std"]) + + def preprocess_batch(self, text, text_len, mels, mel_len): + if self.mean.item() == 0 or self.std.item() == 1: + statistics_dict = torch.load(self.mel_statistics_parameter_path) + self.update_mean_std(statistics_dict) + + mels = self.normalize(mels) + return text, text_len, mels, mel_len + + def normalize(self, x): + return x.sub(self.mean).div(self.std) + + def inverse_normalize(self, x): + return x.mul(self.std).add(self.mean) + + def forward(self, text, text_len, mels, mel_len): + """ + Forward pass for training and computing the log likelihood of a given batch. + + Shapes: + Shapes: + text: :math:`[B, T_in]` + text_len: :math:`[B]` + mels: :math:`[B, T_out, C]` + mel_len: :math:`[B]` + """ + text, text_len, mels, mel_len = self.preprocess_batch(text, text_len, mels, mel_len) + encoder_outputs, encoder_output_len = self.encoder(text, text_len) + + log_probs, fwd_alignments, transition_vectors, means = self.neural_hmm( + encoder_outputs, encoder_output_len, mels.transpose(1, 2), mel_len + ) + + outputs = { + "log_probs": log_probs, + "alignments": fwd_alignments, + "transition_vectors": transition_vectors, + "means": means, + } + + return outputs + + @staticmethod + def _training_stats(batch): + stats = {} + stats["avg_text_length"] = batch["text_lengths"].float().mean() + stats["avg_spec_length"] = batch["mel_lengths"].float().mean() + stats["avg_text_batch_occupancy"] = (batch["text_lengths"].float() / batch["text_lengths"].float().max()).mean() + stats["avg_spec_batch_occupancy"] = (batch["mel_lengths"].float() / batch["mel_lengths"].float().max()).mean() + return stats + + def train_step(self, batch: dict, criterion: nn.Module): + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + + outputs = self.forward( + text=text_input, + text_len=text_lengths, + mels=mel_input, + mel_len=mel_lengths, + ) + loss_dict = criterion(outputs["log_probs"] / (mel_lengths.sum() + text_lengths.sum())) + + # for printing useful statistics on terminal + loss_dict.update(self._training_stats(batch)) + return outputs, loss_dict + + def eval_step(self, batch: Dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def _format_aux_input(self, aux_input: Dict, default_input_dict): + """Set missing fields to their default value. + + Args: + aux_inputs (Dict): Dictionary containing the auxiliary inputs. + """ + default_input_dict.update( + { + "sampling_temp": self.sampling_temp, + "max_sampling_time": self.max_sampling_time, + "duration_threshold": self.duration_threshold, + } + ) + if aux_input: + return format_aux_input(aux_input, default_input_dict) + return None + + @torch.no_grad() + def inference( + self, + text: torch.Tensor, + aux_input={"x_lengths": None, "sampling_temp": None, "max_sampling_time": None, "duration_threshold": None}, + ): # pylint: disable=dangerous-default-value + """Sampling from the model + + Args: + text (torch.Tensor): :math:`[B, T_in]` + aux_inputs (_type_, optional): _description_. Defaults to None. + + Returns: + outputs: Dictionary containing the following + - mel (torch.Tensor): :math:`[B, T_out, C]` + - hmm_outputs_len (torch.Tensor): :math:`[B]` + - state_travelled (List[List[int]]): List of lists containing the state travelled for each sample in the batch. + - input_parameters (list[torch.FloatTensor]): Input parameters to the neural HMM. + - output_parameters (list[torch.FloatTensor]): Output parameters to the neural HMM. + """ + default_input_dict = { + "x_lengths": torch.sum(text != 0, dim=1), + } + aux_input = self._format_aux_input(aux_input, default_input_dict) + encoder_outputs, encoder_output_len = self.encoder.inference(text, aux_input["x_lengths"]) + outputs = self.neural_hmm.inference( + encoder_outputs, + encoder_output_len, + sampling_temp=aux_input["sampling_temp"], + max_sampling_time=aux_input["max_sampling_time"], + duration_threshold=aux_input["duration_threshold"], + ) + mels, mel_outputs_len = outputs["hmm_outputs"], outputs["hmm_outputs_len"] + + mels = self.inverse_normalize(mels) + outputs.update({"model_outputs": mels, "model_outputs_len": mel_outputs_len}) + outputs["alignments"] = OverflowUtils.double_pad(outputs["alignments"]) + return outputs + + @staticmethod + def get_criterion(): + return NLLLoss() + + @staticmethod + def init_from_config(config: "NeuralhmmTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + """Initiate model from config + + Args: + config (VitsConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + verbose (bool): If True, print init messages. Defaults to True. + """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config, verbose) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return NeuralhmmTTS(new_config, ap, tokenizer, speaker_manager) + + def load_checkpoint( + self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def on_init_start(self, trainer): + """If the current dataset does not have normalisation statistics and initialisation transition_probability it computes them otherwise loads.""" + if not os.path.isfile(trainer.config.mel_statistics_parameter_path) or trainer.config.force_generate_statistics: + dataloader = trainer.get_train_dataloader( + training_assets=None, samples=trainer.train_samples, verbose=False + ) + print( + f" | > Data parameters not found for: {trainer.config.mel_statistics_parameter_path}. Computing mel normalization parameters..." + ) + data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start( + dataloader, trainer.config.out_channels, trainer.config.state_per_phone + ) + print( + f" | > Saving data parameters to: {trainer.config.mel_statistics_parameter_path}: value: {data_mean, data_std, init_transition_prob}" + ) + statistics = { + "mean": data_mean.item(), + "std": data_std.item(), + "init_transition_prob": init_transition_prob.item(), + } + torch.save(statistics, trainer.config.mel_statistics_parameter_path) + + else: + print( + f" | > Data parameters found for: {trainer.config.mel_statistics_parameter_path}. Loading mel normalization parameters..." + ) + statistics = torch.load(trainer.config.mel_statistics_parameter_path) + data_mean, data_std, init_transition_prob = ( + statistics["mean"], + statistics["std"], + statistics["init_transition_prob"], + ) + print(f" | > Data parameters loaded with value: {data_mean, data_std, init_transition_prob}") + + trainer.config.flat_start_params["transition_p"] = ( + init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob + ) + OverflowUtils.update_flat_start_transition(trainer.model, init_transition_prob) + trainer.model.update_mean_std(statistics) + + @torch.inference_mode() + def _create_logs(self, batch, outputs, ap): # pylint: disable=no-self-use, unused-argument + alignments, transition_vectors = outputs["alignments"], outputs["transition_vectors"] + means = torch.stack(outputs["means"], dim=1) + + figures = { + "alignment": plot_alignment(alignments[0].exp(), title="Forward alignment", fig_size=(20, 20)), + "log_alignment": plot_alignment( + alignments[0].exp(), title="Forward log alignment", plot_log=True, fig_size=(20, 20) + ), + "transition_vectors": plot_alignment(transition_vectors[0], title="Transition vectors", fig_size=(20, 20)), + "mel_from_most_probable_state": plot_spectrogram( + get_spec_from_most_probable_state(alignments[0], means[0]), fig_size=(12, 3) + ), + "mel_target": plot_spectrogram(batch["mel_input"][0], fig_size=(12, 3)), + } + + # sample one item from the batch -1 will give the smalles item + print(" | > Synthesising audio from the model...") + inference_output = self.inference( + batch["text_input"][-1].unsqueeze(0), aux_input={"x_lenghts": batch["text_lengths"][-1].unsqueeze(0)} + ) + figures["synthesised"] = plot_spectrogram(inference_output["model_outputs"][0], fig_size=(12, 3)) + + states = [p[1] for p in inference_output["input_parameters"][0]] + transition_probability_synthesising = [p[2].cpu().numpy() for p in inference_output["output_parameters"][0]] + + for i in range((len(transition_probability_synthesising) // 200) + 1): + start = i * 200 + end = (i + 1) * 200 + figures[f"synthesised_transition_probabilities/{i}"] = plot_transition_probabilities_to_numpy( + states[start:end], transition_probability_synthesising[start:end] + ) + + audio = ap.inv_melspectrogram(inference_output["model_outputs"][0].T.cpu().numpy()) + return figures, {"audios": audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ): # pylint: disable=unused-argument + """Log training progress.""" + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_log( + self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int + ): # pylint: disable=unused-argument + """Compute and log evaluation metrics.""" + # Plot model parameters histograms + if isinstance(logger, TensorboardLogger): + # I don't know if any other loggers supports this + for tag, value in self.named_parameters(): + tag = tag.replace(".", "/") + logger.writer.add_histogram(tag, value.data.cpu().numpy(), steps) + + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: + logger.test_audios(steps, outputs[1], self.ap.sample_rate) + logger.test_figures(steps, outputs[0]) + + +class NLLLoss(nn.Module): + """Negative log likelihood loss.""" + + def forward(self, log_prob: torch.Tensor) -> dict: # pylint: disable=no-self-use + """Compute the loss. + + Args: + logits (Tensor): [B, T, D] + + Returns: + Tensor: [1] + + """ + return_dict = {} + return_dict["loss"] = -log_prob.mean() + return return_dict diff --git a/TTS/tts/models/overflow.py b/TTS/tts/models/overflow.py index ee5ff411..c2b5b7c2 100644 --- a/TTS/tts/models/overflow.py +++ b/TTS/tts/models/overflow.py @@ -111,9 +111,6 @@ class Overflow(BaseTTS): self.register_buffer("mean", torch.tensor(0)) self.register_buffer("std", torch.tensor(1)) - # self.mean = nn.Parameter(torch.zeros(1), requires_grad=False) - # self.std = nn.Parameter(torch.ones(1), requires_grad=False) - def update_mean_std(self, statistics_dict: Dict): self.mean.data = torch.tensor(statistics_dict["mean"]) self.std.data = torch.tensor(statistics_dict["std"]) diff --git a/recipes/ljspeech/neuralhmm_tts/train_neuralhmmtts.py b/recipes/ljspeech/neuralhmm_tts/train_neuralhmmtts.py new file mode 100644 index 00000000..28d37799 --- /dev/null +++ b/recipes/ljspeech/neuralhmm_tts/train_neuralhmmtts.py @@ -0,0 +1,96 @@ +import os + +from trainer import Trainer, TrainerArgs + +from TTS.config.shared_configs import BaseAudioConfig +from TTS.tts.configs.neuralhmm_tts_config import NeuralhmmTTSConfig +from TTS.tts.configs.shared_configs import BaseDatasetConfig +from TTS.tts.datasets import load_tts_samples +from TTS.tts.models.neuralhmm_tts import NeuralhmmTTS +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.utils.audio import AudioProcessor + +output_path = os.path.dirname(os.path.abspath(__file__)) + +# init configs +dataset_config = BaseDatasetConfig( + formatter="ljspeech", meta_file_train="metadata.csv", path=os.path.join("data", "LJSpeech-1.1/") +) + +audio_config = BaseAudioConfig( + sample_rate=22050, + do_trim_silence=True, + trim_db=60.0, + signal_norm=False, + mel_fmin=0.0, + mel_fmax=8000, + spec_gain=1.0, + log_func="np.log", + ref_level_db=20, + preemphasis=0.0, +) + +config = NeuralhmmTTSConfig( # This is the config that is saved for the future use + run_name="neuralhmmtts_ljspeech", + audio=audio_config, + batch_size=32, + eval_batch_size=16, + num_loader_workers=4, + num_eval_loader_workers=4, + run_eval=True, + test_delay_epochs=-1, + epochs=1000, + text_cleaner="phoneme_cleaners", + use_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), + precompute_num_workers=8, + mel_statistics_parameter_path=os.path.join(output_path, "lj_parameters.pt"), + force_generate_statistics=False, + print_step=1, + print_eval=True, + mixed_precision=True, + output_path=output_path, + datasets=[dataset_config], +) + +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) + +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. +train_samples, eval_samples = load_tts_samples( + dataset_config, + eval_split=True, + eval_split_max_size=config.eval_split_max_size, + eval_split_size=config.eval_split_size, +) + +# INITIALIZE THE MODEL +# Models take a config object and a speaker manager as input +# Config defines the details of the model like the number of layers, the size of the embedding, etc. +# Speaker manager is used by multi-speaker models. +model = NeuralhmmTTS(config, ap, tokenizer) + + +# init the trainer and 🚀 +trainer = Trainer( + TrainerArgs(), + config, + output_path, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + gpu=1, +) +trainer.fit() diff --git a/tests/tts_tests/test_neuralhmm_tts_train.py b/tests/tts_tests/test_neuralhmm_tts_train.py new file mode 100644 index 00000000..25d9aa81 --- /dev/null +++ b/tests/tts_tests/test_neuralhmm_tts_train.py @@ -0,0 +1,92 @@ +import glob +import json +import os +import shutil + +import torch +from trainer import get_last_checkpoint + +from tests import get_device_id, get_tests_output_path, run_cli +from TTS.tts.configs.neuralhmm_tts_config import NeuralhmmTTSConfig + +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") +output_path = os.path.join(get_tests_output_path(), "train_outputs") +parameter_path = os.path.join(get_tests_output_path(), "lj_parameters.pt") + +torch.save({"mean": -5.5138, "std": 2.0636, "init_transition_prob": 0.3212}, parameter_path) + +config = NeuralhmmTTSConfig( + batch_size=3, + eval_batch_size=3, + num_loader_workers=0, + num_eval_loader_workers=0, + text_cleaner="phoneme_cleaners", + use_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path=os.path.join(get_tests_output_path(), "train_outputs/phoneme_cache/"), + run_eval=True, + test_delay_epochs=-1, + mel_statistics_parameter_path=parameter_path, + epochs=1, + print_step=1, + test_sentences=[ + "Be a voice, not an echo.", + ], + print_eval=True, + max_sampling_time=50, +) +config.audio.do_trim_silence = True +config.audio.trim_db = 60 +config.save_json(config_path) + + +# train the model for one epoch when mel parameters exists +command_train = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.datasets.0.formatter ljspeech " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + "--coqpit.test_delay_epochs 0 " +) +run_cli(command_train) + + +# train the model for one epoch when mel parameters have to be computed from the dataset +if os.path.exists(parameter_path): + os.remove(parameter_path) +command_train = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.datasets.0.formatter ljspeech " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + "--coqpit.test_delay_epochs 0 " +) +run_cli(command_train) + +# Find latest folder +continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) + +# Inference using TTS API +continue_config_path = os.path.join(continue_path, "config.json") +continue_restore_path, _ = get_last_checkpoint(continue_path) +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") + +# Check integrity of the config +with open(continue_config_path, "r", encoding="utf-8") as f: + config_loaded = json.load(f) +assert config_loaded["characters"] is not None +assert config_loaded["output_path"] in continue_path +assert config_loaded["test_delay_epochs"] == 0 + +# Load the model and run inference +inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" +run_cli(inference_command) + +# restore the model and continue training for one more epoch +command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " +run_cli(command_train) +shutil.rmtree(continue_path)