mirror of https://github.com/coqui-ai/TTS.git
Adding neural HMM TTS Model (#2272)
* Adding neural HMM TTS
* Adding tests
* Adding neural HMM on README
* Renaming training recipe
* Removing OverFlow's decoder parameters from the config
* Update the Trainer requirement version for a compatible one (#2276)
* Bump up to v0.10.2
* Fixing documentation

Co-authored-by: Edresson Casanova <edresson1@gmail.com>
Co-authored-by: Eren Gölge <erogol@hotmail.com>
This commit is contained in:
parent
497f22b20b
commit
d83ee8fe45
README.md
@@ -95,6 +95,7 @@ Underlined "TTS*" and "Judy*" are 🐸TTS models
 - SC-GlowTTS: [paper](https://arxiv.org/abs/2104.05557)
 - Capacitron: [paper](https://arxiv.org/abs/1906.03402)
 - OverFlow: [paper](https://arxiv.org/abs/2211.06892)
+- Neural HMM TTS: [paper](https://arxiv.org/abs/2108.13320)

 ### End-to-End Models
 - VITS: [paper](https://arxiv.org/pdf/2106.06103)
New file: TTS/tts/configs/neuralhmm_tts_config.py
@@ -0,0 +1,170 @@
from dataclasses import dataclass, field
from typing import List

from TTS.tts.configs.shared_configs import BaseTTSConfig


@dataclass
class NeuralhmmTTSConfig(BaseTTSConfig):
    """
    Define parameters for the Neural HMM TTS model.

    Example:

        >>> from TTS.tts.configs.neuralhmm_tts_config import NeuralhmmTTSConfig
        >>> config = NeuralhmmTTSConfig()

    Args:
        model (str):
            Model name used to select the right model class to initialize. Defaults to `NeuralHMM_TTS`.
        run_eval_steps (int):
            Run an evaluation epoch after N steps. If None, wait until the training epoch is completed. Defaults to 100.
        save_step (int):
            Save a local checkpoint every save_step steps. Defaults to 500.
        plot_step (int):
            Plot training stats on the logger every plot_step steps. Defaults to 1.
        model_param_stats (bool):
            Log model parameter stats on the logger dashboard. Defaults to False.
        force_generate_statistics (bool):
            Force regeneration of the mel normalization statistics. Defaults to False.
        mel_statistics_parameter_path (str):
            Path to the mel normalization statistics. If the model does not find a file there, it will generate the statistics.
            Defaults to None.
        num_chars (int):
            Number of characters used by the model. It must be defined before initializing the model. Defaults to None.
        state_per_phone (int):
            Generates N states per phone. Similar to the `add_blank` parameter in GlowTTS, but here it is upsampled by the model's encoder. Defaults to 2.
        encoder_in_out_features (int):
            Channels of the encoder input and character embedding tensors. Defaults to 512.
        encoder_n_convolutions (int):
            Number of convolution layers in the encoder. Defaults to 3.
        out_channels (int):
            Channels of the final model output. It must match the spectrogram size. Defaults to 80.
        ar_order (int):
            Autoregressive order of the model. The Neural HMM ablations found that more autoregression gives more variation but hurts the naturalness of the synthesized audio. Defaults to 1.
        sampling_temp (float):
            Variation added to the sample from the latent space of the neural HMM. Defaults to 0.
        deterministic_transition (bool):
            Deterministic duration generation based on duration quantiles, as defined in "S. Ronanki, O. Watts, S. King, and G. E. Henter, 'Median-based generation of synthetic speech durations using a non-parametric approach', Proc. SLT, 2016". Defaults to True.
        duration_threshold (float):
            Threshold for duration quantiles. Tune this to change the speaking rate of the synthesis: lower values give a slower speaking rate and higher values a faster one. Defaults to 0.43.
        use_grad_checkpointing (bool):
            Use gradient checkpointing to save memory. In a multi-GPU setting PyTorch currently does not support gradient checkpointing inside a loop, so it has to be turned off there. Adjust depending on how you get the larger batch size, on a single GPU or multi-GPU. Defaults to True.
        max_sampling_time (int):
            Maximum sampling time while synthesizing latents from the neural HMM. Defaults to 1000.
        prenet_type (str):
            `original` or `bn`. `original` sets the default Prenet and `bn` uses the Batch Normalization version of the
            Prenet. Defaults to `original`.
        prenet_dim (int):
            Dimension of the Prenet. Defaults to 256.
        prenet_n_layers (int):
            Number of layers in the Prenet. Defaults to 2.
        prenet_dropout (float):
            Dropout rate of the Prenet. Defaults to 0.5.
        prenet_dropout_at_inference (bool):
            Use dropout at inference time. Defaults to True.
        memory_rnn_dim (int):
            Dimension of the memory LSTM that processes the prenet output. Defaults to 1024.
        outputnet_size (list[int]):
            Size of the output network inside the neural HMM. Defaults to [1024].
        flat_start_params (dict):
            Parameters for the flat start initialization of the neural HMM. It is recomputed when you pass the dataset.
            Defaults to `{"mean": 0.0, "std": 1.0, "transition_p": 0.14}`.
        std_floor (float):
            Floor value for the standard deviation of the neural HMM. Prevents the model from cheating by putting point mass on a datapoint and getting infinite likelihood. It is called `variance flooring` in the standard HMM literature. Defaults to 0.001.
        optimizer (str):
            Optimizer to use for training. Defaults to `Adam`.
        optimizer_params (dict):
            Parameters for the optimizer. Defaults to `{"weight_decay": 1e-6}`.
        grad_clip (float):
            Gradient clipping threshold. Defaults to 40_000.
        lr (float):
            Learning rate. Defaults to 1e-3.
        lr_scheduler (str):
            Learning rate scheduler for training. Use one of the `torch.optim.Scheduler` schedulers or
            `TTS.utils.training`. Defaults to `None`.
        min_seq_len (int):
            Minimum input sequence length to be used at training.
        max_seq_len (int):
            Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
    """

    model: str = "NeuralHMM_TTS"

    # Training and Checkpoint configs
    run_eval_steps: int = 100
    save_step: int = 500
    plot_step: int = 1
    model_param_stats: bool = False

    # data parameters
    force_generate_statistics: bool = False
    mel_statistics_parameter_path: str = None

    # Encoder parameters
    num_chars: int = None
    state_per_phone: int = 2
    encoder_in_out_features: int = 512
    encoder_n_convolutions: int = 3

    # HMM parameters
    out_channels: int = 80
    ar_order: int = 1
    sampling_temp: float = 0
    deterministic_transition: bool = True
    duration_threshold: float = 0.43
    use_grad_checkpointing: bool = True
    max_sampling_time: int = 1000

    ## Prenet parameters
    prenet_type: str = "original"
    prenet_dim: int = 256
    prenet_n_layers: int = 2
    prenet_dropout: float = 0.5
    prenet_dropout_at_inference: bool = True
    memory_rnn_dim: int = 1024

    ## Outputnet parameters
    outputnet_size: List[int] = field(default_factory=lambda: [1024])
    flat_start_params: dict = field(default_factory=lambda: {"mean": 0.0, "std": 1.0, "transition_p": 0.14})
    std_floor: float = 0.001

    # optimizer parameters
    optimizer: str = "Adam"
    optimizer_params: dict = field(default_factory=lambda: {"weight_decay": 1e-6})
    grad_clip: float = 40000.0
    lr: float = 1e-3
    lr_scheduler: str = None

    # overrides
    min_text_len: int = 10
    max_text_len: int = 500
    min_audio_len: int = 512

    # testing
    test_sentences: List[str] = field(
        default_factory=lambda: [
            "Be a voice, not an echo.",
        ]
    )

    # Extra needed config
    r: int = 1
    use_d_vector_file: bool = False
    use_speaker_embedding: bool = False

    def check_values(self):
        """Validate the hyperparameters.

        Raises:
            AssertionError: when the output network is not defined
            AssertionError: when the transition probability is not between 0 and 1
        """
        assert self.ar_order > 0, "AR order must be greater than 0; it is an autoregressive model."
        assert (
            len(self.outputnet_size) >= 1
        ), f"Output network must have at least one layer. Check the config file. Provided: {self.outputnet_size}"
        assert (
            0 < self.flat_start_params["transition_p"] < 1
        ), f"Transition probability must be between 0 and 1. Provided: {self.flat_start_params['transition_p']}"
TTS/tts/layers/overflow/plotting_utils.py
@@ -30,7 +30,7 @@ def validate_numpy_array(value: Any):
     return value


-def get_spec_from_most_probable_state(log_alpha_scaled, means, decoder):
+def get_spec_from_most_probable_state(log_alpha_scaled, means, decoder=None):
     """Get the most probable state means from the log_alpha_scaled.

     Args:
@@ -38,16 +38,21 @@ def get_spec_from_most_probable_state(log_alpha_scaled, means, decoder):
             - Shape: :math:`(T, N)`
         means (torch.Tensor): Means of the states.
             - Shape: :math:`(T, N, D_out)`
-        decoder (torch.nn.Module): Decoder module to decode the latent to melspectrogram
+        decoder (torch.nn.Module): Decoder module to decode the latent to a mel spectrogram. Defaults to None.
     """
     max_state_numbers = torch.max(log_alpha_scaled, dim=1)[1]
     max_len = means.shape[0]
     n_mel_channels = means.shape[2]
     max_state_numbers = max_state_numbers.unsqueeze(1).unsqueeze(1).expand(max_len, 1, n_mel_channels)
     means = torch.gather(means, 1, max_state_numbers).squeeze(1).to(log_alpha_scaled.dtype)
-    mel = (
-        decoder(means.T.unsqueeze(0), torch.tensor([means.shape[0]], device=means.device), reverse=True)[0].squeeze(0).T
-    )
+    if decoder is not None:
+        mel = (
+            decoder(means.T.unsqueeze(0), torch.tensor([means.shape[0]], device=means.device), reverse=True)[0]
+            .squeeze(0)
+            .T
+        )
+    else:
+        mel = means
     return mel
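A quick sketch of the new decoder-free path with dummy tensors (shapes follow the docstring above; the sizes are made up):

import torch
from TTS.tts.layers.overflow.plotting_utils import get_spec_from_most_probable_state

T, N, n_mels = 50, 20, 80  # frames, HMM states, mel channels
log_alpha_scaled = torch.randn(T, N)  # scaled forward log-probabilities
means = torch.randn(T, N, n_mels)     # per-frame state means

# With decoder=None the function now returns the raw state means
# instead of decoding them to a mel spectrogram, which is what
# NeuralhmmTTS (no decoder) needs for its training plots.
mel = get_spec_from_most_probable_state(log_alpha_scaled, means)
print(mel.shape)  # torch.Size([50, 80])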
New file: TTS/tts/models/neuralhmm_tts.py
@@ -0,0 +1,384 @@
import os
from typing import Dict, List, Union

import torch
from coqpit import Coqpit
from torch import nn
from trainer.logging.tensorboard_logger import TensorboardLogger

from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils
from TTS.tts.layers.overflow.neural_hmm import NeuralHMM
from TTS.tts.layers.overflow.plotting_utils import (
    get_spec_from_most_probable_state,
    plot_transition_probabilities_to_numpy,
)
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec


class NeuralhmmTTS(BaseTTS):
    """Neural HMM TTS model.

    Paper::
        https://arxiv.org/abs/2108.13320

    Paper abstract::
        Neural sequence-to-sequence TTS has achieved significantly better output quality
        than statistical speech synthesis using HMMs. However, neural TTS is generally not probabilistic
        and uses non-monotonic attention. Attention failures increase training time and can make
        synthesis babble incoherently. This paper describes how the old and new paradigms can be
        combined to obtain the advantages of both worlds, by replacing attention in neural TTS with
        an autoregressive left-right no-skip hidden Markov model defined by a neural network.
        Based on this proposal, we modify Tacotron 2 to obtain an HMM-based neural TTS model with
        monotonic alignment, trained to maximise the full sequence likelihood without approximation.
        We also describe how to combine ideas from classical and contemporary TTS for best results.
        The resulting example system is smaller and simpler than Tacotron 2, and learns to speak with
        fewer iterations and less data, whilst achieving comparable naturalness prior to the post-net.
        Our approach also allows easy control over speaking rate. Audio examples and code
        are available at https://shivammehta25.github.io/Neural-HMM/ .

    Note:
        - This is a parameter-efficient version of OverFlow (15.3M vs 28.6M parameters). Since it has half the
          number of parameters of OverFlow, the synthesis output quality is suboptimal (but comparable to Tacotron 2
          without the Postnet), yet it learns to speak with even less data and is still significantly faster
          than other attention-based methods.

        - Neural HMM uses flat start initialization, i.e. it computes the means, stds and transition probabilities
          of the dataset and uses them to initialize the model. This benefits the model and helps with faster learning.
          If you change the dataset or want to regenerate the parameters, change `force_generate_statistics` and
          `mel_statistics_parameter_path` accordingly.

        - To enable multi-GPU training, set `use_grad_checkpointing=False` in the config.
          This will significantly increase memory usage. This is because, to compute
          the actual data likelihood (not an approximation using MAS/Viterbi), we must use
          all the states at the previous time step during the forward pass to decide the
          probability distribution at the current step, i.e. the difference between the forward
          algorithm and the Viterbi approximation.

    Check :class:`TTS.tts.configs.neuralhmm_tts_config.NeuralhmmTTSConfig` for class arguments.
    """

    def __init__(
        self,
        config: "NeuralhmmTTSConfig",
        ap: "AudioProcessor" = None,
        tokenizer: "TTSTokenizer" = None,
        speaker_manager: SpeakerManager = None,
    ):
        super().__init__(config, ap, tokenizer, speaker_manager)

        # pass all config fields to `self`
        # for fewer code changes
        self.config = config
        for key in config:
            setattr(self, key, config[key])

        self.encoder = Encoder(config.num_chars, config.state_per_phone, config.encoder_in_out_features)
        self.neural_hmm = NeuralHMM(
            frame_channels=self.out_channels,
            ar_order=self.ar_order,
            deterministic_transition=self.deterministic_transition,
            encoder_dim=self.encoder_in_out_features,
            prenet_type=self.prenet_type,
            prenet_dim=self.prenet_dim,
            prenet_n_layers=self.prenet_n_layers,
            prenet_dropout=self.prenet_dropout,
            prenet_dropout_at_inference=self.prenet_dropout_at_inference,
            memory_rnn_dim=self.memory_rnn_dim,
            outputnet_size=self.outputnet_size,
            flat_start_params=self.flat_start_params,
            std_floor=self.std_floor,
            use_grad_checkpointing=self.use_grad_checkpointing,
        )

        self.register_buffer("mean", torch.tensor(0))
        self.register_buffer("std", torch.tensor(1))

    def update_mean_std(self, statistics_dict: Dict):
        self.mean.data = torch.tensor(statistics_dict["mean"])
        self.std.data = torch.tensor(statistics_dict["std"])

    def preprocess_batch(self, text, text_len, mels, mel_len):
        if self.mean.item() == 0 or self.std.item() == 1:
            statistics_dict = torch.load(self.mel_statistics_parameter_path)
            self.update_mean_std(statistics_dict)

        mels = self.normalize(mels)
        return text, text_len, mels, mel_len

    def normalize(self, x):
        return x.sub(self.mean).div(self.std)

    def inverse_normalize(self, x):
        return x.mul(self.std).add(self.mean)

    def forward(self, text, text_len, mels, mel_len):
        """
        Forward pass for training and computing the log likelihood of a given batch.

        Shapes:
            text: :math:`[B, T_in]`
            text_len: :math:`[B]`
            mels: :math:`[B, T_out, C]`
            mel_len: :math:`[B]`
        """
        text, text_len, mels, mel_len = self.preprocess_batch(text, text_len, mels, mel_len)
        encoder_outputs, encoder_output_len = self.encoder(text, text_len)

        log_probs, fwd_alignments, transition_vectors, means = self.neural_hmm(
            encoder_outputs, encoder_output_len, mels.transpose(1, 2), mel_len
        )

        outputs = {
            "log_probs": log_probs,
            "alignments": fwd_alignments,
            "transition_vectors": transition_vectors,
            "means": means,
        }

        return outputs

    @staticmethod
    def _training_stats(batch):
        stats = {}
        stats["avg_text_length"] = batch["text_lengths"].float().mean()
        stats["avg_spec_length"] = batch["mel_lengths"].float().mean()
        stats["avg_text_batch_occupancy"] = (batch["text_lengths"].float() / batch["text_lengths"].float().max()).mean()
        stats["avg_spec_batch_occupancy"] = (batch["mel_lengths"].float() / batch["mel_lengths"].float().max()).mean()
        return stats

    def train_step(self, batch: dict, criterion: nn.Module):
        text_input = batch["text_input"]
        text_lengths = batch["text_lengths"]
        mel_input = batch["mel_input"]
        mel_lengths = batch["mel_lengths"]

        outputs = self.forward(
            text=text_input,
            text_len=text_lengths,
            mels=mel_input,
            mel_len=mel_lengths,
        )
        # log-likelihood normalized by the total frame and token count of the batch
        loss_dict = criterion(outputs["log_probs"] / (mel_lengths.sum() + text_lengths.sum()))

        # for printing useful statistics on terminal
        loss_dict.update(self._training_stats(batch))
        return outputs, loss_dict

    def eval_step(self, batch: Dict, criterion: nn.Module):
        return self.train_step(batch, criterion)

    def _format_aux_input(self, aux_input: Dict, default_input_dict):
        """Set missing fields to their default value.

        Args:
            aux_input (Dict): Dictionary containing the auxiliary inputs.
        """
        default_input_dict.update(
            {
                "sampling_temp": self.sampling_temp,
                "max_sampling_time": self.max_sampling_time,
                "duration_threshold": self.duration_threshold,
            }
        )
        if aux_input:
            return format_aux_input(aux_input, default_input_dict)
        return None

    @torch.no_grad()
    def inference(
        self,
        text: torch.Tensor,
        aux_input={"x_lengths": None, "sampling_temp": None, "max_sampling_time": None, "duration_threshold": None},
    ):  # pylint: disable=dangerous-default-value
        """Sampling from the model.

        Args:
            text (torch.Tensor): :math:`[B, T_in]`
            aux_input (Dict, optional): Auxiliary inputs (`x_lengths`, `sampling_temp`, `max_sampling_time`,
                `duration_threshold`). Missing fields fall back to the config values.

        Returns:
            outputs: Dictionary containing the following
                - mel (torch.Tensor): :math:`[B, T_out, C]`
                - hmm_outputs_len (torch.Tensor): :math:`[B]`
                - state_travelled (List[List[int]]): List of lists containing the states travelled for each sample in the batch.
                - input_parameters (list[torch.FloatTensor]): Input parameters to the neural HMM.
                - output_parameters (list[torch.FloatTensor]): Output parameters from the neural HMM.
        """
        default_input_dict = {
            "x_lengths": torch.sum(text != 0, dim=1),
        }
        aux_input = self._format_aux_input(aux_input, default_input_dict)
        encoder_outputs, encoder_output_len = self.encoder.inference(text, aux_input["x_lengths"])
        outputs = self.neural_hmm.inference(
            encoder_outputs,
            encoder_output_len,
            sampling_temp=aux_input["sampling_temp"],
            max_sampling_time=aux_input["max_sampling_time"],
            duration_threshold=aux_input["duration_threshold"],
        )
        mels, mel_outputs_len = outputs["hmm_outputs"], outputs["hmm_outputs_len"]

        mels = self.inverse_normalize(mels)
        outputs.update({"model_outputs": mels, "model_outputs_len": mel_outputs_len})
        outputs["alignments"] = OverflowUtils.double_pad(outputs["alignments"])
        return outputs

    @staticmethod
    def get_criterion():
        return NLLLoss()

    @staticmethod
    def init_from_config(config: "NeuralhmmTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
        """Initiate the model from the config.

        Args:
            config (NeuralhmmTTSConfig): Model config.
            samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
                Defaults to None.
            verbose (bool): If True, print init messages. Defaults to True.
        """
        from TTS.utils.audio import AudioProcessor

        ap = AudioProcessor.init_from_config(config, verbose)
        tokenizer, new_config = TTSTokenizer.init_from_config(config)
        speaker_manager = SpeakerManager.init_from_config(config, samples)
        return NeuralhmmTTS(new_config, ap, tokenizer, speaker_manager)

    def load_checkpoint(
        self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False
    ):  # pylint: disable=unused-argument, redefined-builtin
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
        self.load_state_dict(state["model"])
        if eval:
            self.eval()
            assert not self.training

    def on_init_start(self, trainer):
        """Compute the normalisation statistics and initial transition probability for the current dataset if they are missing; otherwise load them."""
        if not os.path.isfile(trainer.config.mel_statistics_parameter_path) or trainer.config.force_generate_statistics:
            dataloader = trainer.get_train_dataloader(
                training_assets=None, samples=trainer.train_samples, verbose=False
            )
            print(
                f" | > Data parameters not found for: {trainer.config.mel_statistics_parameter_path}. Computing mel normalization parameters..."
            )
            data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start(
                dataloader, trainer.config.out_channels, trainer.config.state_per_phone
            )
            print(
                f" | > Saving data parameters to: {trainer.config.mel_statistics_parameter_path}: value: {data_mean, data_std, init_transition_prob}"
            )
            statistics = {
                "mean": data_mean.item(),
                "std": data_std.item(),
                "init_transition_prob": init_transition_prob.item(),
            }
            torch.save(statistics, trainer.config.mel_statistics_parameter_path)

        else:
            print(
                f" | > Data parameters found for: {trainer.config.mel_statistics_parameter_path}. Loading mel normalization parameters..."
            )
            statistics = torch.load(trainer.config.mel_statistics_parameter_path)
            data_mean, data_std, init_transition_prob = (
                statistics["mean"],
                statistics["std"],
                statistics["init_transition_prob"],
            )
            print(f" | > Data parameters loaded with value: {data_mean, data_std, init_transition_prob}")

        trainer.config.flat_start_params["transition_p"] = (
            init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob
        )
        OverflowUtils.update_flat_start_transition(trainer.model, init_transition_prob)
        trainer.model.update_mean_std(statistics)

    @torch.inference_mode()
    def _create_logs(self, batch, outputs, ap):  # pylint: disable=no-self-use, unused-argument
        alignments, transition_vectors = outputs["alignments"], outputs["transition_vectors"]
        means = torch.stack(outputs["means"], dim=1)

        figures = {
            "alignment": plot_alignment(alignments[0].exp(), title="Forward alignment", fig_size=(20, 20)),
            "log_alignment": plot_alignment(
                alignments[0].exp(), title="Forward log alignment", plot_log=True, fig_size=(20, 20)
            ),
            "transition_vectors": plot_alignment(transition_vectors[0], title="Transition vectors", fig_size=(20, 20)),
            "mel_from_most_probable_state": plot_spectrogram(
                get_spec_from_most_probable_state(alignments[0], means[0]), fig_size=(12, 3)
            ),
            "mel_target": plot_spectrogram(batch["mel_input"][0], fig_size=(12, 3)),
        }

        # sample one item from the batch; -1 gives the smallest item
        print(" | > Synthesising audio from the model...")
        inference_output = self.inference(
            batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)}
        )
        figures["synthesised"] = plot_spectrogram(inference_output["model_outputs"][0], fig_size=(12, 3))

        states = [p[1] for p in inference_output["input_parameters"][0]]
        transition_probability_synthesising = [p[2].cpu().numpy() for p in inference_output["output_parameters"][0]]

        for i in range((len(transition_probability_synthesising) // 200) + 1):
            start = i * 200
            end = (i + 1) * 200
            figures[f"synthesised_transition_probabilities/{i}"] = plot_transition_probabilities_to_numpy(
                states[start:end], transition_probability_synthesising[start:end]
            )

        audio = ap.inv_melspectrogram(inference_output["model_outputs"][0].T.cpu().numpy())
        return figures, {"audios": audio}

    def train_log(
        self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
    ):  # pylint: disable=unused-argument
        """Log training progress."""
        figures, audios = self._create_logs(batch, outputs, self.ap)
        logger.train_figures(steps, figures)
        logger.train_audios(steps, audios, self.ap.sample_rate)

    def eval_log(
        self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int
    ):  # pylint: disable=unused-argument
        """Compute and log evaluation metrics."""
        # Plot model parameter histograms
        if isinstance(logger, TensorboardLogger):
            # I don't know if any other logger supports this
            for tag, value in self.named_parameters():
                tag = tag.replace(".", "/")
                logger.writer.add_histogram(tag, value.data.cpu().numpy(), steps)

        figures, audios = self._create_logs(batch, outputs, self.ap)
        logger.eval_figures(steps, figures)
        logger.eval_audios(steps, audios, self.ap.sample_rate)

    def test_log(
        self, outputs: dict, logger: "Logger", assets: dict, steps: int  # pylint: disable=unused-argument
    ) -> None:
        logger.test_audios(steps, outputs[1], self.ap.sample_rate)
        logger.test_figures(steps, outputs[0])


class NLLLoss(nn.Module):
    """Negative log likelihood loss."""

    def forward(self, log_prob: torch.Tensor) -> dict:  # pylint: disable=no-self-use
        """Compute the loss.

        Args:
            log_prob (Tensor): batch log-likelihood, already normalized by the sequence lengths.

        Returns:
            dict: `{"loss": Tensor}` holding the scalar negative log-likelihood.
        """
        return_dict = {}
        return_dict["loss"] = -log_prob.mean()
        return return_dict
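The multi-GPU note in the class docstring comes down to the difference between the forward algorithm and the Viterbi/MAS approximation. A minimal, self-contained log-space sketch of that recursion on a toy left-right no-skip HMM (an illustration only, not the library's `NeuralHMM` implementation; all names here are mine):

import math

import torch


def forward_log_likelihood(log_emission, log_stay, log_next):
    """Exact log-likelihood of a left-right, no-skip HMM via the forward algorithm.

    log_emission: [T, N] log p(frame_t | state_n)
    log_stay, log_next: [N] log-probabilities of staying in / leaving each state.
    """
    T, N = log_emission.shape
    log_alpha = torch.full((N,), float("-inf"))
    log_alpha[0] = log_emission[0, 0]  # generation must start in the first state
    for t in range(1, T):
        stay = log_alpha + log_stay
        advance = torch.full((N,), float("-inf"))
        advance[1:] = log_alpha[:-1] + log_next[:-1]
        # Forward algorithm: sum over both predecessors of each state. The
        # Viterbi/MAS approximation would take a max here instead, which is why
        # exact training needs every state of the previous time step in memory.
        log_alpha = torch.logaddexp(stay, advance) + log_emission[t]
    return log_alpha[-1]  # the full sequence must end in the last state


# toy check: 10 frames, 4 states, 50/50 stay/advance transitions
torch.manual_seed(0)
log_p = torch.log_softmax(torch.randn(10, 4), dim=1)
print(forward_log_likelihood(log_p, torch.full((4,), math.log(0.5)), torch.full((4,), math.log(0.5))))

Checkpointing this loop is what `use_grad_checkpointing` trades compute for memory on; with the max-approximation the sum collapses to a single path and the memory pressure disappears, at the cost of no longer maximising the full sequence likelihood.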
TTS/tts/models/overflow.py
@@ -111,9 +111,6 @@ class Overflow(BaseTTS):
         self.register_buffer("mean", torch.tensor(0))
         self.register_buffer("std", torch.tensor(1))

-        # self.mean = nn.Parameter(torch.zeros(1), requires_grad=False)
-        # self.std = nn.Parameter(torch.ones(1), requires_grad=False)
-
     def update_mean_std(self, statistics_dict: Dict):
         self.mean.data = torch.tensor(statistics_dict["mean"])
         self.std.data = torch.tensor(statistics_dict["std"])
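The deleted comments show the alternative that was considered. A small standalone illustration of why `register_buffer` is the idiomatic choice for non-trainable statistics like these (sketch, not library code):

import torch
from torch import nn


class Stats(nn.Module):
    def __init__(self):
        super().__init__()
        # A buffer is saved in state_dict and moved by .to()/.cuda(),
        # but never yielded by .parameters(), so optimizers ignore it.
        self.register_buffer("mean", torch.tensor(0.0))


m = Stats()
print(m.state_dict())        # OrderedDict([('mean', tensor(0.))])
print(list(m.parameters()))  # [] -- an nn.Parameter(requires_grad=False) would still show up here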
New file (LJSpeech training recipe):
@@ -0,0 +1,96 @@
import os

from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.neuralhmm_tts_config import NeuralhmmTTSConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.neuralhmm_tts import NeuralhmmTTS
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor

output_path = os.path.dirname(os.path.abspath(__file__))

# init configs
dataset_config = BaseDatasetConfig(
    formatter="ljspeech", meta_file_train="metadata.csv", path=os.path.join("data", "LJSpeech-1.1/")
)

audio_config = BaseAudioConfig(
    sample_rate=22050,
    do_trim_silence=True,
    trim_db=60.0,
    signal_norm=False,
    mel_fmin=0.0,
    mel_fmax=8000,
    spec_gain=1.0,
    log_func="np.log",
    ref_level_db=20,
    preemphasis=0.0,
)

config = NeuralhmmTTSConfig(  # This is the config that is saved for future use
    run_name="neuralhmmtts_ljspeech",
    audio=audio_config,
    batch_size=32,
    eval_batch_size=16,
    num_loader_workers=4,
    num_eval_loader_workers=4,
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1000,
    text_cleaner="phoneme_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
    precompute_num_workers=8,
    mel_statistics_parameter_path=os.path.join(output_path, "lj_parameters.pt"),
    force_generate_statistics=False,
    print_step=1,
    print_eval=True,
    mixed_precision=True,
    output_path=output_path,
    datasets=[dataset_config],
)

# INITIALIZE THE AUDIO PROCESSOR
# The audio processor is used for feature extraction and audio I/O.
# It mainly serves the dataloader and the training loggers.
ap = AudioProcessor.init_from_config(config)

# INITIALIZE THE TOKENIZER
# The tokenizer is used to convert text to sequences of token IDs.
# If characters are not defined in the config, default characters are passed to the config.
tokenizer, config = TTSTokenizer.init_from_config(config)

# LOAD DATA SAMPLES
# Each sample is a list of ```[text, audio_file_path, speaker_name]```
# You can define your custom sample loader returning the list of samples.
# Or define your custom formatter and pass it to `load_tts_samples`.
# Check `TTS.tts.datasets.load_tts_samples` for more details.
train_samples, eval_samples = load_tts_samples(
    dataset_config,
    eval_split=True,
    eval_split_max_size=config.eval_split_max_size,
    eval_split_size=config.eval_split_size,
)

# INITIALIZE THE MODEL
# Models take a config object and a speaker manager as input.
# The config defines the details of the model, like the number of layers, the size of the embedding, etc.
# The speaker manager is used by multi-speaker models.
model = NeuralhmmTTS(config, ap, tokenizer)


# init the trainer and 🚀
trainer = Trainer(
    TrainerArgs(),
    config,
    output_path,
    model=model,
    train_samples=train_samples,
    eval_samples=eval_samples,
    gpu=1,
)
trainer.fit()
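Once the recipe has produced a checkpoint, synthesis can be sketched roughly as follows (the `run_dir` paths are illustrative placeholders for the output folder the Trainer creates):

import torch

from TTS.tts.configs.neuralhmm_tts_config import NeuralhmmTTSConfig
from TTS.tts.models.neuralhmm_tts import NeuralhmmTTS

# Illustrative paths; point these at the run folder the Trainer produced.
config = NeuralhmmTTSConfig()
config.load_json("run_dir/config.json")
model = NeuralhmmTTS.init_from_config(config)
model.load_checkpoint(config, "run_dir/best_model.pth", eval=True)

# Tokenize a sentence and sample a mel spectrogram from the model.
token_ids = torch.LongTensor(model.tokenizer.text_to_ids("Be a voice, not an echo.")).unsqueeze(0)
outputs = model.inference(token_ids)
mel = outputs["model_outputs"]  # [1, T_out, out_channels]; vocode e.g. with model.ap.inv_melspectrogram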
New file (training test):
@@ -0,0 +1,92 @@
import glob
import json
import os
import shutil

import torch
from trainer import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.neuralhmm_tts_config import NeuralhmmTTSConfig

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
parameter_path = os.path.join(get_tests_output_path(), "lj_parameters.pt")

torch.save({"mean": -5.5138, "std": 2.0636, "init_transition_prob": 0.3212}, parameter_path)

config = NeuralhmmTTSConfig(
    batch_size=3,
    eval_batch_size=3,
    num_loader_workers=0,
    num_eval_loader_workers=0,
    text_cleaner="phoneme_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path=os.path.join(get_tests_output_path(), "train_outputs/phoneme_cache/"),
    run_eval=True,
    test_delay_epochs=-1,
    mel_statistics_parameter_path=parameter_path,
    epochs=1,
    print_step=1,
    test_sentences=[
        "Be a voice, not an echo.",
    ],
    print_eval=True,
    max_sampling_time=50,
)
config.audio.do_trim_silence = True
config.audio.trim_db = 60
config.save_json(config_path)


# train the model for one epoch when the mel parameters already exist
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
    "--coqpit.datasets.0.formatter ljspeech "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
    "--coqpit.test_delay_epochs 0 "
)
run_cli(command_train)


# train the model for one epoch when the mel parameters have to be computed from the dataset
if os.path.exists(parameter_path):
    os.remove(parameter_path)
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
    "--coqpit.datasets.0.formatter ljspeech "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
    "--coqpit.test_delay_epochs 0 "
)
run_cli(command_train)

# find the latest run folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

# inference using the TTS API
continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")

# check the integrity of the config
with open(continue_config_path, "r", encoding="utf-8") as f:
    config_loaded = json.load(f)
assert config_loaded["characters"] is not None
assert config_loaded["output_path"] in continue_path
assert config_loaded["test_delay_epochs"] == 0

# load the model and run inference
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
run_cli(inference_command)

# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)