Merge branch 'dev' into main

This commit is contained in:
Eren Gölge 2023-02-06 11:25:33 +01:00
commit 910a218652
53 changed files with 2088 additions and 106 deletions

View File

@ -3,7 +3,6 @@
----
### 📣 Clone your voice with a single click on [🐸Coqui.ai](https://app.coqui.ai/auth/signin)
### 📣 🐸Coqui Studio is launching soon!! Join our [waiting list](https://coqui.ai/)!!
----
@ -92,9 +91,11 @@ Underlined "TTS*" and "Judy*" are 🐸TTS models
- Align-TTS: [paper](https://arxiv.org/abs/2003.01950)
- FastPitch: [paper](https://arxiv.org/pdf/2006.06873.pdf)
- FastSpeech: [paper](https://arxiv.org/abs/1905.09263)
- FastSpeech2: [paper](https://arxiv.org/abs/2006.04558)
- SC-GlowTTS: [paper](https://arxiv.org/abs/2104.05557)
- Capacitron: [paper](https://arxiv.org/abs/1906.03402)
- OverFlow: [paper](https://arxiv.org/abs/2211.06892)
- Neural HMM TTS: [paper](https://arxiv.org/abs/2108.13320)
### End-to-End Models
- VITS: [paper](https://arxiv.org/pdf/2106.06103)
@ -190,6 +191,12 @@ tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.langu
tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False, gpu=False)
# Run TTS
tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path=OUTPUT_PATH)
# Example voice cloning with YourTTS in English, French and Portuguese:
tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True)
tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="output.wav")
tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="output.wav")
tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="output.wav")
```
### Command line `tts`

View File

@ -4,7 +4,7 @@
"multi-dataset":{
"your_tts":{
"description": "Your TTS model accompanying the paper https://arxiv.org/abs/2112.02418",
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--multilingual--multi-dataset--your_tts.zip",
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.10.1_models/tts_models--multilingual--multi-dataset--your_tts.zip",
"default_vocoder": null,
"commit": "e9a1953e",
"license": "CC BY-NC-ND 4.0",

View File

@ -1 +1 @@
0.10.1
0.10.2

View File

@ -7,7 +7,16 @@ from TTS.utils.synthesizer import Synthesizer
class TTS:
"""TODO: Add voice conversion and Capacitron support."""
def __init__(self, model_name: str = None, progress_bar: bool = True, gpu=False):
def __init__(
self,
model_name: str = None,
model_path: str = None,
config_path: str = None,
vocoder_path: str = None,
vocoder_config_path: str = None,
progress_bar: bool = True,
gpu=False,
):
"""🐸TTS python interface that allows to load and use the released models.
Example with a multi-speaker model:
@ -20,8 +29,22 @@ class TTS:
>>> tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False, gpu=False)
>>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav")
Example loading a model from a path:
>>> tts = TTS(model_path="/path/to/checkpoint_100000.pth", config_path="/path/to/config.json", progress_bar=False, gpu=False)
>>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav")
Example voice cloning with YourTTS in English, French and Portuguese:
>>> tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True)
>>> tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="thisisit.wav")
>>> tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="thisisit.wav")
>>> tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="thisisit.wav")
Args:
model_name (str, optional): Model name to load. You can list models by ```tts.models```. Defaults to None.
model_path (str, optional): Path to the model checkpoint. Defaults to None.
config_path (str, optional): Path to the model config. Defaults to None.
vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None.
vocoder_config_path (str, optional): Path to the vocoder config. Defaults to None.
progress_bar (bool, optional): Whether to print a progress bar while downloading a model. Defaults to True.
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
"""
@ -29,6 +52,10 @@ class TTS:
self.synthesizer = None
if model_name:
self.load_model_by_name(model_name, gpu)
if model_path:
self.load_model_by_path(
model_path, config_path, vocoder_path=vocoder_path, vocoder_config=vocoder_config_path, gpu=gpu
)
@property
def models(self):
@ -75,7 +102,17 @@ class TTS:
return model_path, config_path, vocoder_path, vocoder_config_path
def load_model_by_name(self, model_name: str, gpu: bool = False):
""" Load one of 🐸TTS models by name.
Args:
model_name (str): Model name to load. You can list models by ```tts.models```.
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
TODO: Add tests
"""
model_path, config_path, vocoder_path, vocoder_config_path = self.download_model_by_name(model_name)
# init synthesizer
# None values are fetched from the model
self.synthesizer = Synthesizer(
@ -90,8 +127,33 @@ class TTS:
use_cuda=gpu,
)
def _check_arguments(self, speaker: str = None, language: str = None):
if self.is_multi_speaker and speaker is None:
def load_model_by_path(
self, model_path: str, config_path: str, vocoder_path: str = None, vocoder_config: str = None, gpu: bool = False
):
"""Load a model from a path.
Args:
model_path (str): Path to the model checkpoint.
config_path (str): Path to the model config.
vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None.
vocoder_config (str, optional): Path to the vocoder config. Defaults to None.
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
"""
self.synthesizer = Synthesizer(
tts_checkpoint=model_path,
tts_config_path=config_path,
tts_speakers_file=None,
tts_languages_file=None,
vocoder_checkpoint=vocoder_path,
vocoder_config=vocoder_config,
encoder_checkpoint=None,
encoder_config=None,
use_cuda=gpu,
)
def _check_arguments(self, speaker: str = None, language: str = None, speaker_wav: str = None):
if self.is_multi_speaker and (speaker is None and speaker_wav is None):
raise ValueError("Model is multi-speaker but no speaker is provided.")
if self.is_multi_lingual and language is None:
raise ValueError("Model is multi-lingual but no language is provided.")
@ -100,7 +162,7 @@ class TTS:
if not self.is_multi_lingual and language is not None:
raise ValueError("Model is not multi-lingual but language is provided.")
def tts(self, text: str, speaker: str = None, language: str = None):
def tts(self, text: str, speaker: str = None, language: str = None, speaker_wav: str = None):
"""Convert text to speech.
Args:
@ -112,14 +174,17 @@ class TTS:
language (str, optional):
Language code for multi-lingual models. You can check whether loaded model is multi-lingual
`tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None.
speaker_wav (str, optional):
Path to a reference wav file to use for voice cloning with supporting models like YourTTS.
Defaults to None.
"""
self._check_arguments(speaker=speaker, language=language)
self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav)
wav = self.synthesizer.tts(
text=text,
speaker_name=speaker,
language_name=language,
speaker_wav=None,
speaker_wav=speaker_wav,
reference_wav=None,
style_wav=None,
style_text=None,
@ -127,7 +192,14 @@ class TTS:
)
return wav
def tts_to_file(self, text: str, speaker: str = None, language: str = None, file_path: str = "output.wav"):
def tts_to_file(
self,
text: str,
speaker: str = None,
language: str = None,
speaker_wav: str = None,
file_path: str = "output.wav",
):
"""Convert text to speech.
Args:
@ -139,8 +211,11 @@ class TTS:
language (str, optional):
Language code for multi-lingual models. You can check whether loaded model is multi-lingual
`tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None.
speaker_wav (str, optional):
Path to a reference wav file to use for voice cloning with supporting models like YourTTS.
Defaults to None.
file_path (str, optional):
Output file path. Defaults to "output.wav".
"""
wav = self.tts(text=text, speaker=speaker, language=language)
wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav)
self.synthesizer.save_wav(wav=wav, path=file_path)
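
For completeness, here is a minimal sketch of the new `load_model_by_path` API combined with a separately trained vocoder; every path below is a placeholder and the vocoder files are optional.

```python
from TTS.api import TTS

# Sketch only: load a local checkpoint plus an optional external vocoder.
# All paths are placeholders for your own files.
tts = TTS(progress_bar=False, gpu=False)
tts.load_model_by_path(
    model_path="/path/to/checkpoint_100000.pth",
    config_path="/path/to/config.json",
    vocoder_path="/path/to/vocoder_checkpoint.pth",  # optional
    vocoder_config="/path/to/vocoder_config.json",   # optional
    gpu=False,
)
# Assumes a single-speaker, single-language model.
tts.tts_to_file(text="Hello world!", file_path="output.wav")
```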

View File

@ -2,8 +2,8 @@ import argparse
import glob
import os
from argparse import RawTextHelpFormatter
from distutils.dir_util import copy_tree
from multiprocessing import Pool
from shutil import copytree
import librosa
import soundfile as sf
@ -19,7 +19,7 @@ def resample_file(func_args):
def resample_files(input_dir, output_sr, output_dir=None, file_ext="wav", n_jobs=10):
if output_dir:
print("Recursively copying the input folder...")
copy_tree(input_dir, output_dir)
copytree(input_dir, output_dir)
input_dir = output_dir
print("Resampling the audio files...")

View File

@ -212,6 +212,9 @@ class BaseDatasetConfig(Coqpit):
language (str):
Language code of the dataset. If defined, it overrides `phoneme_language`. Defaults to `""`.
phonemizer (str):
Phonemizer used for that dataset's language. By default it uses `DEF_LANG_TO_PHONEMIZER`. Defaults to `""`.
meta_file_val (str):
Name of the dataset meta file that defines the instances used at validation.
@ -226,6 +229,7 @@ class BaseDatasetConfig(Coqpit):
meta_file_train: str = ""
ignored_speakers: List[str] = None
language: str = ""
phonemizer: str = ""
meta_file_val: str = ""
meta_file_attn_mask: str = ""

View File

@ -115,8 +115,13 @@ synthesizer = Synthesizer(
use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and (
synthesizer.tts_model.num_speakers > 1 or synthesizer.tts_speakers_file is not None
)
speaker_manager = getattr(synthesizer.tts_model, "speaker_manager", None)
use_multi_language = hasattr(synthesizer.tts_model, "num_languages") and (
synthesizer.tts_model.num_languages > 1 or synthesizer.tts_languages_file is not None
)
language_manager = getattr(synthesizer.tts_model, "language_manager", None)
# TODO: set this from SpeakerManager
use_gst = synthesizer.tts_config.get("use_gst", False)
app = Flask(__name__)
@ -147,7 +152,9 @@ def index():
"index.html",
show_details=args.show_details,
use_multi_speaker=use_multi_speaker,
use_multi_language=use_multi_language,
speaker_ids=speaker_manager.name_to_id if speaker_manager is not None else None,
language_ids=language_manager.name_to_id if language_manager is not None else None,
use_gst=use_gst,
)
@ -177,11 +184,13 @@ def tts():
with lock:
text = request.args.get("text")
speaker_idx = request.args.get("speaker_id", "")
language_idx = request.args.get("language_id", "")
style_wav = request.args.get("style_wav", "")
style_wav = style_wav_uri_to_dict(style_wav)
print(" > Model input: {}".format(text))
print(" > Speaker Idx: {}".format(speaker_idx))
wavs = synthesizer.tts(text, speaker_name=speaker_idx, style_wav=style_wav)
print(" > Language Idx: {}".format(language_idx))
wavs = synthesizer.tts(text, speaker_name=speaker_idx, language_name=language_idx, style_wav=style_wav)
out = io.BytesIO()
synthesizer.save_wav(wavs, out)
return send_file(out, mimetype="audio/wav")

View File

@ -65,7 +65,7 @@
</ul>
{%if use_gst%}
<input value='{"0": 0.1}' id="style_wav" placeholder="style wav (dict or path ot wav).." size=45
<input value='{"0": 0.1}' id="style_wav" placeholder="style wav (dict or path to wav).." size=45
type="text" name="style_wav">
{%endif%}
@ -81,6 +81,16 @@
</select><br /><br />
{%endif%}
{%if use_multi_language%}
Choose a language:
<select id="language_id" name=language_id method="GET" action="/">
{% for language_id in language_ids %}
<option value="{{language_id}}" SELECTED>{{language_id}}</option>"
{% endfor %}
</select><br /><br />
{%endif%}
{%if show_details%}
<button id="details-button" onclick="location.href = 'details'" name="model-details">Model
Details</button><br /><br />
@ -106,11 +116,12 @@
const text = q('#text').value
const speaker_id = getTextValue('#speaker_id')
const style_wav = getTextValue('#style_wav')
const language_id = getTextValue('#language_id')
if (text) {
q('#message').textContent = 'Synthesizing...'
q('#speak-button').disabled = true
q('#audio').hidden = true
synthesize(text, speaker_id, style_wav)
synthesize(text, speaker_id, style_wav, language_id)
}
e.preventDefault()
return false
@ -121,8 +132,8 @@
do_tts(e)
}
})
function synthesize(text, speaker_id = "", style_wav = "") {
fetch(`/api/tts?text=${encodeURIComponent(text)}&speaker_id=${encodeURIComponent(speaker_id)}&style_wav=${encodeURIComponent(style_wav)}`, { cache: 'no-cache' })
function synthesize(text, speaker_id = "", style_wav = "", language_id = "") {
fetch(`/api/tts?text=${encodeURIComponent(text)}&speaker_id=${encodeURIComponent(speaker_id)}&style_wav=${encodeURIComponent(style_wav)}&language_id=${encodeURIComponent(language_id)}`, { cache: 'no-cache' })
.then(function (res) {
if (!res.ok) throw Error(res.statusText)
return res.blob()

View File

@ -100,6 +100,13 @@ class FastPitchConfig(BaseTTSConfig):
max_seq_len (int):
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
# dataset configs
compute_f0 (bool):
Compute pitch. Defaults to True.
f0_cache_path (str):
Pitch cache path. Defaults to None.
"""
model: str = "fast_pitch"

View File

@ -0,0 +1,198 @@
from dataclasses import dataclass, field
from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.forward_tts import ForwardTTSArgs
@dataclass
class Fastspeech2Config(BaseTTSConfig):
"""Configure `ForwardTTS` as FastPitch model.
Example:
>>> from TTS.tts.configs.fastspeech2_config import FastSpeech2Config
>>> config = FastSpeech2Config()
Args:
model (str):
Model name used for selecting the right model at initialization. Defaults to `fastspeech2`.
base_model (str):
Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate
the base model rather than searching for the `model` implementation. Defaults to `forward_tts`.
model_args (Coqpit):
Model class arguments. Check `ForwardTTSArgs` for more details. Defaults to `ForwardTTSArgs()`.
data_dep_init_steps (int):
Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses
Activation Normalization that pre-computes normalization stats at the beginning and uses the same values
for the rest. Defaults to 10.
speakers_file (str):
Path to the file containing the list of speakers. Needed at inference for loading matching speaker ids to
speaker names. Defaults to `None`.
use_speaker_embedding (bool):
enable / disable using speaker embeddings for multi-speaker models. If set True, the model is
in the multi-speaker mode. Defaults to False.
use_d_vector_file (bool):
enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
d_vector_file (str):
Path to the file including pre-computed speaker embeddings. Defaults to None.
d_vector_dim (int):
Dimension of the external speaker embeddings. Defaults to 0.
optimizer (str):
Name of the model optimizer. Defaults to `Adam`.
optimizer_params (dict):
Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`.
lr_scheduler (str):
Name of the learning rate scheduler. Defaults to `Noam`.
lr_scheduler_params (dict):
Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`.
lr (float):
Initial learning rate. Defaults to `1e-3`.
grad_clip (float):
Gradient norm clipping value. Defaults to `5.0`.
spec_loss_type (str):
Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.
duration_loss_type (str):
Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.
use_ssim_loss (bool):
Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True.
wd (float):
Weight decay coefficient. Defaults to `1e-7`.
ssim_loss_alpha (float):
Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0.
dur_loss_alpha (float):
Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0.
spec_loss_alpha (float):
Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0.
pitch_loss_alpha (float):
Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0.
energy_loss_alpha (float):
Weight for the energy predictor's loss. If set 0, disables the energy predictor. Defaults to 1.0.
binary_align_loss_alpha (float):
Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0.
binary_loss_warmup_epochs (float):
Number of epochs to gradually increase the binary loss impact. Defaults to 150.
min_seq_len (int):
Minimum input sequence length to be used at training.
max_seq_len (int):
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
# dataset configs
compute_f0 (bool):
Compute pitch. Defaults to True.
f0_cache_path (str):
Pitch cache path. Defaults to None.
# dataset configs
compute_energy (bool):
Compute energy. Defaults to True.
energy_cache_path (str):
Energy cache path. Defaults to None.
"""
model: str = "fastspeech2"
base_model: str = "forward_tts"
# model specific params
model_args: ForwardTTSArgs = ForwardTTSArgs()
# multi-speaker settings
num_speakers: int = 0
speakers_file: str = None
use_speaker_embedding: bool = False
use_d_vector_file: bool = False
d_vector_file: str = None
d_vector_dim: int = 0
# optimizer parameters
optimizer: str = "Adam"
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
lr_scheduler: str = "NoamLR"
lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
lr: float = 1e-4
grad_clip: float = 5.0
# loss params
spec_loss_type: str = "mse"
duration_loss_type: str = "mse"
use_ssim_loss: bool = True
ssim_loss_alpha: float = 1.0
spec_loss_alpha: float = 1.0
aligner_loss_alpha: float = 1.0
pitch_loss_alpha: float = 0.1
energy_loss_alpha: float = 0.1
dur_loss_alpha: float = 0.1
binary_align_loss_alpha: float = 0.1
binary_loss_warmup_epochs: int = 150
# overrides
min_seq_len: int = 13
max_seq_len: int = 200
r: int = 1 # DO NOT CHANGE
# dataset configs
compute_f0: bool = True
f0_cache_path: str = None
# dataset configs
compute_energy: bool = True
energy_cache_path: str = None
# testing
test_sentences: List[str] = field(
default_factory=lambda: [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963.",
]
)
def __post_init__(self):
# Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there.
if self.num_speakers > 0:
self.model_args.num_speakers = self.num_speakers
# speaker embedding settings
if self.use_speaker_embedding:
self.model_args.use_speaker_embedding = True
if self.speakers_file:
self.model_args.speakers_file = self.speakers_file
# d-vector settings
if self.use_d_vector_file:
self.model_args.use_d_vector_file = True
if self.d_vector_dim is not None and self.d_vector_dim > 0:
self.model_args.d_vector_dim = self.d_vector_dim
if self.d_vector_file:
self.model_args.d_vector_file = self.d_vector_file
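
As a usage sketch, the new dataset-level options can be set when instantiating the config; the cache paths are placeholders and the module path is assumed to mirror the other config files.

```python
from TTS.tts.configs.fastspeech2_config import Fastspeech2Config  # module path assumed

# Minimal sketch enabling the precomputed pitch and energy features.
config = Fastspeech2Config(
    compute_f0=True,
    f0_cache_path="cache/f0",          # placeholder
    compute_energy=True,
    energy_cache_path="cache/energy",  # placeholder
)
print(config.model, config.base_model)  # fastspeech2 forward_tts
```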

View File

@ -0,0 +1,170 @@
from dataclasses import dataclass, field
from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig
@dataclass
class NeuralhmmTTSConfig(BaseTTSConfig):
"""
Define parameters for Neural HMM TTS model.
Example:
>>> from TTS.tts.configs.neuralhmm_tts_config import NeuralhmmTTSConfig
>>> config = NeuralhmmTTSConfig()
Args:
model (str):
Model name used to select the right model class to initialize. Defaults to `NeuralHMM_TTS`.
run_eval_steps (int):
Run an evaluation epoch after N steps. If None, it waits until the training epoch is completed. Defaults to None.
save_step (int):
Save local checkpoint every save_step steps. Defaults to 500.
plot_step (int):
Plot training stats on the logger every plot_step steps. Defaults to 1.
model_param_stats (bool):
Log model parameters stats on the logger dashboard. Defaults to False.
force_generate_statistics (bool):
Force generate mel normalization statistics. Defaults to False.
mel_statistics_parameter_path (str):
Path to the mel normalization statistics. If the model doesn't find a file there, it will generate the statistics.
Defaults to None.
num_chars (int):
Number of characters used by the model. It must be defined before initializing the model. Defaults to None.
state_per_phone (int):
Generates N states per phone. Similar to the `add_blank` parameter in GlowTTS, but here it is upsampled by the model's encoder. Defaults to 2.
encoder_in_out_features (int):
Channels of encoder input and character embedding tensors. Defaults to 512.
encoder_n_convolutions (int):
Number of convolution layers in the encoder. Defaults to 3.
out_channels (int):
Channels of the final model output. It must match the spectrogram size. Defaults to 80.
ar_order (int):
Autoregressive order of the model. Defaults to 1. Ablations of Neural HMM found that higher autoregressive orders give more variation but hurt the naturalness of the synthesised audio.
sampling_temp (float):
Variation added to the sample from the latent space of neural HMM. Defaults to 0.334.
deterministic_transition (bool):
Deterministic duration generation based on duration quantiles as defined in "S. Ronanki, O. Watts, S. King, and G. E. Henter, “Median-based generation of synthetic speech durations using a non-parametric approach,” in Proc. SLT, 2016." Defaults to True.
duration_threshold (float):
Threshold for duration quantiles. Defaults to 0.43. Tune this to change the speaking rate of the synthesis; lower values give a slower speaking rate and higher values a faster one.
use_grad_checkpointing (bool):
Use gradient checkpointing to save memory. PyTorch currently does not support gradient checkpointing inside a loop in a multi-GPU setting, so it has to be turned off there. Choose whichever setting gives you the larger batch size, a single GPU with checkpointing or multiple GPUs without it. Defaults to True.
max_sampling_time (int):
Maximum sampling time while synthesising latents from neural HMM. Defaults to 1000.
prenet_type (str):
`original` or `bn`. `original` sets the default Prenet and `bn` uses Batch Normalization version of the
Prenet. Defaults to `original`.
prenet_dim (int):
Dimension of the Prenet. Defaults to 256.
prenet_n_layers (int):
Number of layers in the Prenet. Defaults to 2.
prenet_dropout (float):
Dropout rate of the Prenet. Defaults to 0.5.
prenet_dropout_at_inference (bool):
Use dropout at inference time. Defaults to False.
memory_rnn_dim (int):
Dimension of the memory LSTM to process the prenet output. Defaults to 1024.
outputnet_size (list[int]):
Size of the output network inside the neural HMM. Defaults to [1024].
flat_start_params (dict):
Parameters for the flat start initialization of the neural HMM. Defaults to `{"mean": 0.0, "std": 1.0, "transition_p": 0.14}`.
It will be recomputed when you pass the dataset.
std_floor (float):
Floor value for the standard deviation of the neural HMM. Prevents the model from cheating by placing a point mass and getting infinite likelihood at any data point. Defaults to 0.001.
It is called `variance flooring` in standard HMM literature.
optimizer (str):
Optimizer to use for training. Defaults to `adam`.
optimizer_params (dict):
Parameters for the optimizer. Defaults to `{"weight_decay": 1e-6}`.
grad_clip (float):
Gradient clipping threshold. Defaults to 40_000.
lr (float):
Learning rate. Defaults to 1e-3.
lr_scheduler (str):
Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or
`TTS.utils.training`. Defaults to `None`.
min_seq_len (int):
Minimum input sequence length to be used at training.
max_seq_len (int):
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
"""
model: str = "NeuralHMM_TTS"
# Training and Checkpoint configs
run_eval_steps: int = 100
save_step: int = 500
plot_step: int = 1
model_param_stats: bool = False
# data parameters
force_generate_statistics: bool = False
mel_statistics_parameter_path: str = None
# Encoder parameters
num_chars: int = None
state_per_phone: int = 2
encoder_in_out_features: int = 512
encoder_n_convolutions: int = 3
# HMM parameters
out_channels: int = 80
ar_order: int = 1
sampling_temp: float = 0
deterministic_transition: bool = True
duration_threshold: float = 0.43
use_grad_checkpointing: bool = True
max_sampling_time: int = 1000
## Prenet parameters
prenet_type: str = "original"
prenet_dim: int = 256
prenet_n_layers: int = 2
prenet_dropout: float = 0.5
prenet_dropout_at_inference: bool = True
memory_rnn_dim: int = 1024
## Outputnet parameters
outputnet_size: List[int] = field(default_factory=lambda: [1024])
flat_start_params: dict = field(default_factory=lambda: {"mean": 0.0, "std": 1.0, "transition_p": 0.14})
std_floor: float = 0.001
# optimizer parameters
optimizer: str = "Adam"
optimizer_params: dict = field(default_factory=lambda: {"weight_decay": 1e-6})
grad_clip: float = 40000.0
lr: float = 1e-3
lr_scheduler: str = None
# overrides
min_text_len: int = 10
max_text_len: int = 500
min_audio_len: int = 512
# testing
test_sentences: List[str] = field(
default_factory=lambda: [
"Be a voice, not an echo.",
]
)
# Extra needed config
r: int = 1
use_d_vector_file: bool = False
use_speaker_embedding: bool = False
def check_values(self):
"""Validate the hyperparameters.
Raises:
AssertionError: when the parameters network is not defined
AssertionError: transition probability is not between 0 and 1
"""
assert self.ar_order > 0, "AR order must be greater than 0; it is an autoregressive model."
assert (
len(self.outputnet_size) >= 1
), f"Parameter Network must have atleast one layer check the config file for parameter network. Provided: {self.parameternetwork}"
assert (
0 < self.flat_start_params["transition_p"] < 1
), f"Transition probability must be between 0 and 1. Provided: {self.flat_start_params['transition_p']}"

View File

@ -217,6 +217,9 @@ class BaseTTSConfig(BaseTrainingConfig):
compute_f0 (bool):
(Not in use yet).
compute_energy (bool):
(Not in use yet).
compute_linear_spec (bool):
If True data loader computes and returns linear spectrograms alongside the other data.
@ -312,6 +315,7 @@ class BaseTTSConfig(BaseTrainingConfig):
min_text_len: int = 1
max_text_len: int = float("inf")
compute_f0: bool = False
compute_energy: bool = False
compute_linear_spec: bool = False
precompute_num_workers: int = 0
use_noise_augment: bool = False

View File

@ -167,7 +167,7 @@ class VitsConfig(BaseTTSConfig):
# use d-vectors
use_d_vector_file: bool = False
d_vector_file: str = None
d_vector_file: List[str] = None
d_vector_dim: int = None
def __post_init__(self):
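
With this change `d_vector_file` accepts a list of speaker-embedding files instead of a single path; an illustrative sketch with hypothetical file names, set through the model args:

```python
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import VitsArgs

# Hypothetical d-vector files; the field now takes a list of paths.
config = VitsConfig(
    model_args=VitsArgs(
        use_d_vector_file=True,
        d_vector_file=["speakers_part1.json", "speakers_part2.json"],
        d_vector_dim=512,
    )
)
```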

View File

@ -11,6 +11,7 @@ from torch.utils.data import Dataset
from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import compute_energy as calculate_energy
# to prevent too many open files error as suggested here
# https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
@ -50,7 +51,9 @@ class TTSDataset(Dataset):
samples: List[Dict] = None,
tokenizer: "TTSTokenizer" = None,
compute_f0: bool = False,
compute_energy: bool = False,
f0_cache_path: str = None,
energy_cache_path: str = None,
return_wav: bool = False,
batch_group_size: int = 0,
min_text_len: int = 0,
@ -84,8 +87,12 @@ class TTSDataset(Dataset):
compute_f0 (bool): compute f0 if True. Defaults to False.
compute_energy (bool): compute energy if True. Defaults to False.
f0_cache_path (str): Path to store f0 cache. Defaults to None.
energy_cache_path (str): Path to store energy cache. Defaults to None.
return_wav (bool): Return the waveform of the sample. Defaults to False.
batch_group_size (int): Range of batch randomization after sorting
@ -128,7 +135,9 @@ class TTSDataset(Dataset):
self.compute_linear_spec = compute_linear_spec
self.return_wav = return_wav
self.compute_f0 = compute_f0
self.compute_energy = compute_energy
self.f0_cache_path = f0_cache_path
self.energy_cache_path = energy_cache_path
self.min_audio_len = min_audio_len
self.max_audio_len = max_audio_len
self.min_text_len = min_text_len
@ -155,7 +164,10 @@ class TTSDataset(Dataset):
self.f0_dataset = F0Dataset(
self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
)
if compute_energy:
self.energy_dataset = EnergyDataset(
self.samples, self.ap, cache_path=energy_cache_path, precompute_num_workers=precompute_num_workers
)
if self.verbose:
self.print_logs()
@ -211,6 +223,12 @@ class TTSDataset(Dataset):
assert item["audio_unique_name"] == out_dict["audio_unique_name"]
return out_dict
def get_energy(self, idx):
out_dict = self.energy_dataset[idx]
item = self.samples[idx]
assert item["audio_unique_name"] == out_dict["audio_unique_name"]
return out_dict
@staticmethod
def get_attn_mask(attn_file):
return np.load(attn_file)
@ -252,12 +270,16 @@ class TTSDataset(Dataset):
f0 = None
if self.compute_f0:
f0 = self.get_f0(idx)["f0"]
energy = None
if self.compute_energy:
energy = self.get_energy(idx)["energy"]
sample = {
"raw_text": raw_text,
"token_ids": token_ids,
"wav": wav,
"pitch": f0,
"energy": energy,
"attn": attn,
"item_idx": item["audio_file"],
"speaker_name": item["speaker_name"],
@ -490,7 +512,13 @@ class TTSDataset(Dataset):
pitch = torch.FloatTensor(pitch)[:, None, :].contiguous() # B x 1 xT
else:
pitch = None
# format energy
if self.compute_energy:
energy = prepare_data(batch["energy"])
assert mel.shape[1] == energy.shape[1], f"[!] {mel.shape} vs {energy.shape}"
energy = torch.FloatTensor(energy)[:, None, :].contiguous() # B x 1 xT
else:
energy = None
# format attention masks
attns = None
if batch["attn"][0] is not None:
@ -519,6 +547,7 @@ class TTSDataset(Dataset):
"waveform": wav_padded,
"raw_text": batch["raw_text"],
"pitch": pitch,
"energy": energy,
"language_ids": language_ids,
"audio_unique_names": batch["audio_unique_name"],
}
@ -569,14 +598,14 @@ class PhonemeDataset(Dataset):
def __getitem__(self, index):
item = self.samples[index]
ids = self.compute_or_load(string2filename(item["audio_unique_name"]), item["text"])
ids = self.compute_or_load(string2filename(item["audio_unique_name"]), item["text"], item["language"])
ph_hat = self.tokenizer.ids_to_text(ids)
return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}
def __len__(self):
return len(self.samples)
def compute_or_load(self, file_name, text):
def compute_or_load(self, file_name, text, language):
"""Compute phonemes for the given text.
If the phonemes are already cached, load them from cache.
@ -586,7 +615,7 @@ class PhonemeDataset(Dataset):
try:
ids = np.load(cache_path)
except FileNotFoundError:
ids = self.tokenizer.text_to_ids(text)
ids = self.tokenizer.text_to_ids(text, language=language)
np.save(cache_path, ids)
return ids
@ -777,3 +806,155 @@ class F0Dataset:
print("\n")
print(f"{indent}> F0Dataset ")
print(f"{indent}| > Number of instances : {len(self.samples)}")
class EnergyDataset:
"""Energy Dataset for computing Energy from wav files in CPU
Pre-compute Energy values for all the samples at initialization if `cache_path` is not None or already present. It
also computes the mean and std of Energy values if `normalize_Energy` is True.
Args:
samples (Union[List[List], List[Dict]]):
List of samples. Each sample is a list or a dict.
ap (AudioProcessor):
AudioProcessor to compute Energy from wav files.
cache_path (str):
Path to cache Energy values. If `cache_path` is already present or None, it skips the pre-computation.
Defaults to None.
precompute_num_workers (int):
Number of workers used for pre-computing the Energy values. Defaults to 0.
normalize_Energy (bool):
Whether to normalize Energy values by mean and std. Defaults to True.
"""
def __init__(
self,
samples: Union[List[List], List[Dict]],
ap: "AudioProcessor",
verbose=False,
cache_path: str = None,
precompute_num_workers=0,
normalize_energy=True,
):
self.samples = samples
self.ap = ap
self.verbose = verbose
self.cache_path = cache_path
self.normalize_energy = normalize_energy
self.pad_id = 0.0
self.mean = None
self.std = None
if cache_path is not None and not os.path.exists(cache_path):
os.makedirs(cache_path)
self.precompute(precompute_num_workers)
if normalize_energy:
self.load_stats(cache_path)
def __getitem__(self, idx):
item = self.samples[idx]
energy = self.compute_or_load(item["audio_file"])
if self.normalize_energy:
assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available"
energy = self.normalize(energy)
return {"audio_file": item["audio_file"], "energy": energy}
def __len__(self):
return len(self.samples)
def precompute(self, num_workers=0):
print("[*] Pre-computing energys...")
with tqdm.tqdm(total=len(self)) as pbar:
batch_size = num_workers if num_workers > 0 else 1
# we do not normalize at preprocessing
normalize_energy = self.normalize_energy
self.normalize_energy = False
dataloader = torch.utils.data.DataLoader(
batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
)
computed_data = []
for batch in dataloader:
energy = batch["energy"]
computed_data.append(e for e in energy)
pbar.update(batch_size)
self.normalize_energy = normalize_energy
if self.normalize_energy:
computed_data = [tensor for batch in computed_data for tensor in batch] # flatten
energy_mean, energy_std = self.compute_energy_stats(computed_data)
energy_stats = {"mean": energy_mean, "std": energy_std}
np.save(os.path.join(self.cache_path, "energy_stats"), energy_stats, allow_pickle=True)
def get_pad_id(self):
return self.pad_id
@staticmethod
def create_energy_file_path(wav_file, cache_path):
file_name = os.path.splitext(os.path.basename(wav_file))[0]
energy_file = os.path.join(cache_path, file_name + "_energy.npy")
return energy_file
@staticmethod
def _compute_and_save_energy(ap, wav_file, energy_file=None):
wav = ap.load_wav(wav_file)
energy = calculate_energy(wav)
if energy_file:
np.save(energy_file, energy)
return energy
@staticmethod
def compute_energy_stats(energy_vecs):
nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in energy_vecs])
mean, std = np.mean(nonzeros), np.std(nonzeros)
return mean, std
def load_stats(self, cache_path):
stats_path = os.path.join(cache_path, "energy_stats.npy")
stats = np.load(stats_path, allow_pickle=True).item()
self.mean = stats["mean"].astype(np.float32)
self.std = stats["std"].astype(np.float32)
def normalize(self, energy):
zero_idxs = np.where(energy == 0.0)[0]
energy = energy - self.mean
energy = energy / self.std
energy[zero_idxs] = 0.0
return energy
def denormalize(self, energy):
zero_idxs = np.where(energy == 0.0)[0]
energy *= self.std
energy += self.mean
energy[zero_idxs] = 0.0
return energy
def compute_or_load(self, wav_file):
"""
compute energy and return a numpy array of energy values
"""
energy_file = self.create_energy_file_path(wav_file, self.cache_path)
if not os.path.exists(energy_file):
energy = self._compute_and_save_energy(self.ap, wav_file, energy_file)
else:
energy = np.load(energy_file)
return energy.astype(np.float32)
def collate_fn(self, batch):
audio_file = [item["audio_file"] for item in batch]
energys = [item["energy"] for item in batch]
energy_lens = [len(item["energy"]) for item in batch]
energy_lens_max = max(energy_lens)
energys_torch = torch.FloatTensor(len(energys), energy_lens_max).fill_(self.get_pad_id())
for i, energy_len in enumerate(energy_lens):
energys_torch[i, :energy_len] = torch.FloatTensor(energys[i])
return {"audio_file": audio_file, "energy": energys_torch, "energy_lens": energy_lens}
def print_logs(self, level: int = 0) -> None:
indent = "\t" * level
print("\n")
print(f"{indent}> energyDataset ")
print(f"{indent}| > Number of instances : {len(self.samples)}")

View File

@ -1,6 +1,5 @@
from distutils.version import LooseVersion
import torch
from packaging.version import Version
from torch import nn
from torch.nn import functional as F
@ -91,7 +90,7 @@ class InvConvNear(nn.Module):
self.no_jacobian = no_jacobian
self.weight_inv = None
if LooseVersion(torch.__version__) < LooseVersion("1.9"):
if Version(torch.__version__) < Version("1.9"):
w_init = torch.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0]
else:
w_init = torch.linalg.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_(), "complete")[0]

View File

@ -801,6 +801,10 @@ class ForwardTTSLoss(nn.Module):
self.pitch_loss = MSELossMasked(False)
self.pitch_loss_alpha = c.pitch_loss_alpha
if c.model_args.use_energy:
self.energy_loss = MSELossMasked(False)
self.energy_loss_alpha = c.energy_loss_alpha
if c.use_ssim_loss:
self.ssim = SSIMLoss() if c.use_ssim_loss else None
self.ssim_loss_alpha = c.ssim_loss_alpha
@ -826,6 +830,8 @@ class ForwardTTSLoss(nn.Module):
dur_target,
pitch_output,
pitch_target,
energy_output,
energy_target,
input_lens,
alignment_logprob=None,
alignment_hard=None,
@ -855,6 +861,11 @@ class ForwardTTSLoss(nn.Module):
loss = loss + self.pitch_loss_alpha * pitch_loss
return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss
if hasattr(self, "energy_loss") and self.energy_loss_alpha > 0:
energy_loss = self.energy_loss(energy_output.transpose(1, 2), energy_target.transpose(1, 2), input_lens)
loss = loss + self.energy_loss_alpha * energy_loss
return_dict["loss_energy"] = self.energy_loss_alpha * energy_loss
if hasattr(self, "aligner_loss") and self.aligner_loss_alpha > 0:
aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens)
loss = loss + self.aligner_loss_alpha * aligner_loss

View File

@ -30,7 +30,7 @@ def validate_numpy_array(value: Any):
return value
def get_spec_from_most_probable_state(log_alpha_scaled, means, decoder):
def get_spec_from_most_probable_state(log_alpha_scaled, means, decoder=None):
"""Get the most probable state means from the log_alpha_scaled.
Args:
@ -38,16 +38,21 @@ def get_spec_from_most_probable_state(log_alpha_scaled, means, decoder):
- Shape: :math:`(T, N)`
means (torch.Tensor): Means of the states.
- Shape: :math:`(N, T, D_out)`
decoder (torch.nn.Module): Decoder module to decode the latent to melspectrogram
decoder (torch.nn.Module): Decoder module to decode the latent to melspectrogram. Defaults to None.
"""
max_state_numbers = torch.max(log_alpha_scaled, dim=1)[1]
max_len = means.shape[0]
n_mel_channels = means.shape[2]
max_state_numbers = max_state_numbers.unsqueeze(1).unsqueeze(1).expand(max_len, 1, n_mel_channels)
means = torch.gather(means, 1, max_state_numbers).squeeze(1).to(log_alpha_scaled.dtype)
mel = (
decoder(means.T.unsqueeze(0), torch.tensor([means.shape[0]], device=means.device), reverse=True)[0].squeeze(0).T
)
if decoder is not None:
mel = (
decoder(means.T.unsqueeze(0), torch.tensor([means.shape[0]], device=means.device), reverse=True)[0]
.squeeze(0)
.T
)
else:
mel = means
return mel

View File

@ -345,7 +345,7 @@ class BaseTTS(BaseTrainerModel):
loader = DataLoader(
dataset,
batch_size=config.eval_batch_size if is_eval else config.batch_size,
shuffle=config.shuffle if sampler is not None else False, # if there is no other sampler
shuffle=config.shuffle if sampler is None else False, # if there is no other sampler
collate_fn=dataset.collate_fn,
drop_last=config.drop_last, # setting this False might cause issues in AMP training.
sampler=sampler,

View File

@ -15,7 +15,7 @@ from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram
from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram
from TTS.utils.io import load_fsspec
@ -42,6 +42,9 @@ class ForwardTTSArgs(Coqpit):
use_pitch (bool):
Use pitch predictor to learn the pitch. Defaults to True.
use_energy (bool):
Use energy predictor to learn the energy. Defaults to False.
duration_predictor_hidden_channels (int):
Number of hidden channels in the duration predictor. Defaults to 256.
@ -63,6 +66,18 @@ class ForwardTTSArgs(Coqpit):
pitch_embedding_kernel_size (int):
Kernel size of the projection layer in the pitch predictor. Defaults to 3.
energy_predictor_hidden_channels (int):
Number of hidden channels in the energy predictor. Defaults to 256.
energy_predictor_dropout_p (float):
Dropout rate for the energy predictor. Defaults to 0.1.
energy_predictor_kernel_size (int):
Kernel size of conv layers in the energy predictor. Defaults to 3.
energy_embedding_kernel_size (int):
Kernel size of the projection layer in the energy predictor. Defaults to 3.
positional_encoding (bool):
Whether to use positional encoding. Defaults to True.
@ -114,14 +129,25 @@ class ForwardTTSArgs(Coqpit):
out_channels: int = 80
hidden_channels: int = 384
use_aligner: bool = True
# pitch params
use_pitch: bool = True
pitch_predictor_hidden_channels: int = 256
pitch_predictor_kernel_size: int = 3
pitch_predictor_dropout_p: float = 0.1
pitch_embedding_kernel_size: int = 3
# energy params
use_energy: bool = False
energy_predictor_hidden_channels: int = 256
energy_predictor_kernel_size: int = 3
energy_predictor_dropout_p: float = 0.1
energy_embedding_kernel_size: int = 3
# duration params
duration_predictor_hidden_channels: int = 256
duration_predictor_kernel_size: int = 3
duration_predictor_dropout_p: float = 0.1
positional_encoding: bool = True
poisitonal_encoding_use_scale: bool = True
length_scale: int = 1
@ -158,7 +184,7 @@ class ForwardTTS(BaseTTS):
- FastPitch
- SpeedySpeech
- FastSpeech
- TODO: FastSpeech2 (requires average speech energy predictor)
- FastSpeech2 (requires average speech energy predictor)
Args:
config (Coqpit): Model coqpit class.
@ -187,6 +213,7 @@ class ForwardTTS(BaseTTS):
self.max_duration = self.args.max_duration
self.use_aligner = self.args.use_aligner
self.use_pitch = self.args.use_pitch
self.use_energy = self.args.use_energy
self.binary_loss_weight = 0.0
self.length_scale = (
@ -234,6 +261,20 @@ class ForwardTTS(BaseTTS):
padding=int((self.args.pitch_embedding_kernel_size - 1) / 2),
)
if self.args.use_energy:
self.energy_predictor = DurationPredictor(
self.args.hidden_channels + self.embedded_speaker_dim,
self.args.energy_predictor_hidden_channels,
self.args.energy_predictor_kernel_size,
self.args.energy_predictor_dropout_p,
)
self.energy_emb = nn.Conv1d(
1,
self.args.hidden_channels,
kernel_size=self.args.energy_embedding_kernel_size,
padding=int((self.args.energy_embedding_kernel_size - 1) / 2),
)
if self.args.use_aligner:
self.aligner = AlignmentNetwork(
in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels
@ -440,6 +481,42 @@ class ForwardTTS(BaseTTS):
o_pitch_emb = self.pitch_emb(o_pitch)
return o_pitch_emb, o_pitch
def _forward_energy_predictor(
self,
o_en: torch.FloatTensor,
x_mask: torch.IntTensor,
energy: torch.FloatTensor = None,
dr: torch.IntTensor = None,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
"""Energy predictor forward pass.
1. Predict energy from encoder outputs.
2. In training - Compute average energy values for each input character from the ground truth energy values.
3. Embed average energy values.
Args:
o_en (torch.FloatTensor): Encoder output.
x_mask (torch.IntTensor): Input sequence mask.
energy (torch.FloatTensor, optional): Ground truth energy values. Defaults to None.
dr (torch.IntTensor, optional): Ground truth durations. Defaults to None.
Returns:
Tuple[torch.FloatTensor, torch.FloatTensor]: Energy embedding, energy prediction.
Shapes:
- o_en: :math:`(B, C, T_{en})`
- x_mask: :math:`(B, 1, T_{en})`
- energy: :math:`(B, 1, T_{de})`
- dr: :math:`(B, T_{en})`
"""
o_energy = self.energy_predictor(o_en, x_mask)
if energy is not None:
avg_energy = average_over_durations(energy, dr)
o_energy_emb = self.energy_emb(avg_energy)
return o_energy_emb, o_energy, avg_energy
o_energy_emb = self.energy_emb(o_energy)
return o_energy_emb, o_energy
def _forward_aligner(
self, x: torch.FloatTensor, y: torch.FloatTensor, x_mask: torch.IntTensor, y_mask: torch.IntTensor
) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
@ -502,6 +579,7 @@ class ForwardTTS(BaseTTS):
y: torch.FloatTensor = None,
dr: torch.IntTensor = None,
pitch: torch.FloatTensor = None,
energy: torch.FloatTensor = None,
aux_input: Dict = {"d_vectors": None, "speaker_ids": None}, # pylint: disable=unused-argument
) -> Dict:
"""Model's forward pass.
@ -513,6 +591,7 @@ class ForwardTTS(BaseTTS):
y (torch.FloatTensor): Spectrogram frames. Only used when the alignment network is on. Defaults to None.
dr (torch.IntTensor): Character durations over the spectrogram frames. Only used when the alignment network is off. Defaults to None.
pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Only used when the pitch predictor is on. Defaults to None.
energy (torch.FloatTensor): energy values for each spectrogram frame. Only used when the energy predictor is on. Defaults to None.
aux_input (Dict): Auxiliary model inputs for multi-speaker training. Defaults to `{"d_vectors": 0, "speaker_ids": None}`.
Shapes:
@ -556,6 +635,12 @@ class ForwardTTS(BaseTTS):
if self.args.use_pitch:
o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en, x_mask, pitch, dr)
o_en = o_en + o_pitch_emb
# energy predictor pass
o_energy = None
avg_energy = None
if self.args.use_energy:
o_energy_emb, o_energy, avg_energy = self._forward_energy_predictor(o_en, x_mask, energy, dr)
o_en = o_en + o_energy_emb
# decoder pass
o_de, attn = self._forward_decoder(
o_en, dr, x_mask, y_lengths, g=None
@ -567,6 +652,8 @@ class ForwardTTS(BaseTTS):
"attn_durations": o_attn, # for visualization [B, T_en, T_de']
"pitch_avg": o_pitch,
"pitch_avg_gt": avg_pitch,
"energy_avg": o_energy,
"energy_avg_gt": avg_energy,
"alignments": attn, # [B, T_de, T_en]
"alignment_soft": alignment_soft,
"alignment_mas": alignment_mas,
@ -604,12 +691,18 @@ class ForwardTTS(BaseTTS):
if self.args.use_pitch:
o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask)
o_en = o_en + o_pitch_emb
# energy predictor pass
o_energy = None
if self.args.use_energy:
o_energy_emb, o_energy = self._forward_energy_predictor(o_en, x_mask)
o_en = o_en + o_energy_emb
# decoder pass
o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=None)
outputs = {
"model_outputs": o_de,
"alignments": attn,
"pitch": o_pitch,
"energy": o_energy,
"durations_log": o_dr_log,
}
return outputs
@ -620,6 +713,7 @@ class ForwardTTS(BaseTTS):
mel_input = batch["mel_input"]
mel_lengths = batch["mel_lengths"]
pitch = batch["pitch"] if self.args.use_pitch else None
energy = batch["energy"] if self.args.use_energy else None
d_vectors = batch["d_vectors"]
speaker_ids = batch["speaker_ids"]
durations = batch["durations"]
@ -627,7 +721,14 @@ class ForwardTTS(BaseTTS):
# forward pass
outputs = self.forward(
text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input
text_input,
text_lengths,
mel_lengths,
y=mel_input,
dr=durations,
pitch=pitch,
energy=energy,
aux_input=aux_input,
)
# use aligner's output as the duration target
if self.use_aligner:
@ -643,6 +744,8 @@ class ForwardTTS(BaseTTS):
dur_target=durations,
pitch_output=outputs["pitch_avg"] if self.use_pitch else None,
pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None,
energy_output=outputs["energy_avg"] if self.use_energy else None,
energy_target=outputs["energy_avg_gt"] if self.use_energy else None,
input_lens=text_lengths,
alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None,
alignment_soft=outputs["alignment_soft"],
@ -683,6 +786,17 @@ class ForwardTTS(BaseTTS):
}
figures.update(pitch_figures)
# plot energy figures
if self.args.use_energy:
energy_avg = abs(outputs["energy_avg_gt"][0, 0].data.cpu().numpy())
energy_avg_hat = abs(outputs["energy_avg"][0, 0].data.cpu().numpy())
chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy())
energy_figures = {
"energy_ground_truth": plot_avg_energy(energy_avg, chars, output_fig=False),
"energy_avg_predicted": plot_avg_energy(energy_avg_hat, chars, output_fig=False),
}
figures.update(energy_figures)
# plot the attention mask computed from the predicted durations
if "attn_durations" in outputs:
alignments_hat = outputs["attn_durations"][0].data.cpu().numpy()

View File

@ -0,0 +1,384 @@
import os
from typing import Dict, List, Union
import torch
from coqpit import Coqpit
from torch import nn
from trainer.logging.tensorboard_logger import TensorboardLogger
from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils
from TTS.tts.layers.overflow.neural_hmm import NeuralHMM
from TTS.tts.layers.overflow.plotting_utils import (
get_spec_from_most_probable_state,
plot_transition_probabilities_to_numpy,
)
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec
class NeuralhmmTTS(BaseTTS):
"""Neural HMM TTS model.
Paper::
https://arxiv.org/abs/2108.13320
Paper abstract::
Neural sequence-to-sequence TTS has achieved significantly better output quality
than statistical speech synthesis using HMMs. However, neural TTS is generally not probabilistic
and uses non-monotonic attention. Attention failures increase training time and can make
synthesis babble incoherently. This paper describes how the old and new paradigms can be
combined to obtain the advantages of both worlds, by replacing attention in neural TTS with
an autoregressive left-right no-skip hidden Markov model defined by a neural network.
Based on this proposal, we modify Tacotron 2 to obtain an HMM-based neural TTS model with
monotonic alignment, trained to maximise the full sequence likelihood without approximation.
We also describe how to combine ideas from classical and contemporary TTS for best results.
The resulting example system is smaller and simpler than Tacotron 2, and learns to speak with
fewer iterations and less data, whilst achieving comparable naturalness prior to the post-net.
Our approach also allows easy control over speaking rate. Audio examples and code
are available at https://shivammehta25.github.io/Neural-HMM/ .
Note:
- This is a parameter-efficient version of OverFlow (15.3M vs 28.6M parameters). Since it has roughly half the
number of parameters of OverFlow, the synthesis output quality is lower (but comparable to Tacotron 2
without the Postnet), while it learns to speak with even less data and is still significantly faster
than other attention-based methods.
- Neural HMM uses flat start initialization, i.e. it computes the mean, std, and transition probability
of the dataset and uses them to initialize the model. This benefits the model and helps with faster learning.
If you change the dataset or want to regenerate the parameters, change `force_generate_statistics` and
`mel_statistics_parameter_path` accordingly.
- To enable multi-GPU training, set `use_grad_checkpointing=False` in the config. This will significantly
increase memory usage, because computing the actual data likelihood (not an approximation using MAS/Viterbi)
requires all the states at the previous time step during the forward pass to decide the probability
distribution at the current step, which is the difference between the forward algorithm and the Viterbi
approximation.
Check :class:`TTS.tts.configs.neuralhmm_tts_config.NeuralhmmTTSConfig` for class arguments.
"""
def __init__(
self,
config: "NeuralhmmTTSConfig",
ap: "AudioProcessor" = None,
tokenizer: "TTSTokenizer" = None,
speaker_manager: SpeakerManager = None,
):
super().__init__(config, ap, tokenizer, speaker_manager)
# pass all config fields to `self`
# for fewer code changes
self.config = config
for key in config:
setattr(self, key, config[key])
self.encoder = Encoder(config.num_chars, config.state_per_phone, config.encoder_in_out_features)
self.neural_hmm = NeuralHMM(
frame_channels=self.out_channels,
ar_order=self.ar_order,
deterministic_transition=self.deterministic_transition,
encoder_dim=self.encoder_in_out_features,
prenet_type=self.prenet_type,
prenet_dim=self.prenet_dim,
prenet_n_layers=self.prenet_n_layers,
prenet_dropout=self.prenet_dropout,
prenet_dropout_at_inference=self.prenet_dropout_at_inference,
memory_rnn_dim=self.memory_rnn_dim,
outputnet_size=self.outputnet_size,
flat_start_params=self.flat_start_params,
std_floor=self.std_floor,
use_grad_checkpointing=self.use_grad_checkpointing,
)
self.register_buffer("mean", torch.tensor(0))
self.register_buffer("std", torch.tensor(1))
def update_mean_std(self, statistics_dict: Dict):
self.mean.data = torch.tensor(statistics_dict["mean"])
self.std.data = torch.tensor(statistics_dict["std"])
def preprocess_batch(self, text, text_len, mels, mel_len):
if self.mean.item() == 0 or self.std.item() == 1:
statistics_dict = torch.load(self.mel_statistics_parameter_path)
self.update_mean_std(statistics_dict)
mels = self.normalize(mels)
return text, text_len, mels, mel_len
def normalize(self, x):
return x.sub(self.mean).div(self.std)
def inverse_normalize(self, x):
return x.mul(self.std).add(self.mean)
def forward(self, text, text_len, mels, mel_len):
"""
Forward pass for training and computing the log likelihood of a given batch.
Shapes:
text: :math:`[B, T_in]`
text_len: :math:`[B]`
mels: :math:`[B, T_out, C]`
mel_len: :math:`[B]`
"""
text, text_len, mels, mel_len = self.preprocess_batch(text, text_len, mels, mel_len)
encoder_outputs, encoder_output_len = self.encoder(text, text_len)
log_probs, fwd_alignments, transition_vectors, means = self.neural_hmm(
encoder_outputs, encoder_output_len, mels.transpose(1, 2), mel_len
)
outputs = {
"log_probs": log_probs,
"alignments": fwd_alignments,
"transition_vectors": transition_vectors,
"means": means,
}
return outputs
@staticmethod
def _training_stats(batch):
stats = {}
stats["avg_text_length"] = batch["text_lengths"].float().mean()
stats["avg_spec_length"] = batch["mel_lengths"].float().mean()
stats["avg_text_batch_occupancy"] = (batch["text_lengths"].float() / batch["text_lengths"].float().max()).mean()
stats["avg_spec_batch_occupancy"] = (batch["mel_lengths"].float() / batch["mel_lengths"].float().max()).mean()
return stats
def train_step(self, batch: dict, criterion: nn.Module):
text_input = batch["text_input"]
text_lengths = batch["text_lengths"]
mel_input = batch["mel_input"]
mel_lengths = batch["mel_lengths"]
outputs = self.forward(
text=text_input,
text_len=text_lengths,
mels=mel_input,
mel_len=mel_lengths,
)
loss_dict = criterion(outputs["log_probs"] / (mel_lengths.sum() + text_lengths.sum()))
# for printing useful statistics on terminal
loss_dict.update(self._training_stats(batch))
return outputs, loss_dict
def eval_step(self, batch: Dict, criterion: nn.Module):
return self.train_step(batch, criterion)
def _format_aux_input(self, aux_input: Dict, default_input_dict):
"""Set missing fields to their default value.
Args:
aux_inputs (Dict): Dictionary containing the auxiliary inputs.
"""
default_input_dict.update(
{
"sampling_temp": self.sampling_temp,
"max_sampling_time": self.max_sampling_time,
"duration_threshold": self.duration_threshold,
}
)
if aux_input:
return format_aux_input(aux_input, default_input_dict)
return None
@torch.no_grad()
def inference(
self,
text: torch.Tensor,
aux_input={"x_lengths": None, "sampling_temp": None, "max_sampling_time": None, "duration_threshold": None},
): # pylint: disable=dangerous-default-value
"""Sampling from the model
Args:
text (torch.Tensor): :math:`[B, T_in]`
aux_input (Dict, optional): Auxiliary inputs (`x_lengths`, `sampling_temp`, `max_sampling_time`, `duration_threshold`). Missing values are filled from the config defaults.
Returns:
outputs: Dictionary containing the following
- mel (torch.Tensor): :math:`[B, T_out, C]`
- hmm_outputs_len (torch.Tensor): :math:`[B]`
- state_travelled (List[List[int]]): List of lists containing the state travelled for each sample in the batch.
- input_parameters (list[torch.FloatTensor]): Input parameters to the neural HMM.
- output_parameters (list[torch.FloatTensor]): Output parameters to the neural HMM.
"""
default_input_dict = {
"x_lengths": torch.sum(text != 0, dim=1),
}
aux_input = self._format_aux_input(aux_input, default_input_dict)
encoder_outputs, encoder_output_len = self.encoder.inference(text, aux_input["x_lengths"])
outputs = self.neural_hmm.inference(
encoder_outputs,
encoder_output_len,
sampling_temp=aux_input["sampling_temp"],
max_sampling_time=aux_input["max_sampling_time"],
duration_threshold=aux_input["duration_threshold"],
)
mels, mel_outputs_len = outputs["hmm_outputs"], outputs["hmm_outputs_len"]
mels = self.inverse_normalize(mels)
outputs.update({"model_outputs": mels, "model_outputs_len": mel_outputs_len})
outputs["alignments"] = OverflowUtils.double_pad(outputs["alignments"])
return outputs
@staticmethod
def get_criterion():
return NLLLoss()
@staticmethod
def init_from_config(config: "NeuralhmmTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
"""Initiate model from config
Args:
config (NeuralhmmTTSConfig): Model config.
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
Defaults to None.
verbose (bool): If True, print init messages. Defaults to True.
"""
from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config, verbose)
tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config, samples)
return NeuralhmmTTS(new_config, ap, tokenizer, speaker_manager)
def load_checkpoint(
self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False
): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()
assert not self.training
def on_init_start(self, trainer):
"""If the current dataset does not have normalisation statistics and initialisation transition_probability it computes them otherwise loads."""
if not os.path.isfile(trainer.config.mel_statistics_parameter_path) or trainer.config.force_generate_statistics:
dataloader = trainer.get_train_dataloader(
training_assets=None, samples=trainer.train_samples, verbose=False
)
print(
f" | > Data parameters not found for: {trainer.config.mel_statistics_parameter_path}. Computing mel normalization parameters..."
)
data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start(
dataloader, trainer.config.out_channels, trainer.config.state_per_phone
)
print(
f" | > Saving data parameters to: {trainer.config.mel_statistics_parameter_path}: value: {data_mean, data_std, init_transition_prob}"
)
statistics = {
"mean": data_mean.item(),
"std": data_std.item(),
"init_transition_prob": init_transition_prob.item(),
}
torch.save(statistics, trainer.config.mel_statistics_parameter_path)
else:
print(
f" | > Data parameters found for: {trainer.config.mel_statistics_parameter_path}. Loading mel normalization parameters..."
)
statistics = torch.load(trainer.config.mel_statistics_parameter_path)
data_mean, data_std, init_transition_prob = (
statistics["mean"],
statistics["std"],
statistics["init_transition_prob"],
)
print(f" | > Data parameters loaded with value: {data_mean, data_std, init_transition_prob}")
trainer.config.flat_start_params["transition_p"] = (
init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob
)
OverflowUtils.update_flat_start_transition(trainer.model, init_transition_prob)
trainer.model.update_mean_std(statistics)
@torch.inference_mode()
def _create_logs(self, batch, outputs, ap): # pylint: disable=no-self-use, unused-argument
alignments, transition_vectors = outputs["alignments"], outputs["transition_vectors"]
means = torch.stack(outputs["means"], dim=1)
figures = {
"alignment": plot_alignment(alignments[0].exp(), title="Forward alignment", fig_size=(20, 20)),
"log_alignment": plot_alignment(
alignments[0].exp(), title="Forward log alignment", plot_log=True, fig_size=(20, 20)
),
"transition_vectors": plot_alignment(transition_vectors[0], title="Transition vectors", fig_size=(20, 20)),
"mel_from_most_probable_state": plot_spectrogram(
get_spec_from_most_probable_state(alignments[0], means[0]), fig_size=(12, 3)
),
"mel_target": plot_spectrogram(batch["mel_input"][0], fig_size=(12, 3)),
}
# sample one item from the batch; index -1 gives the smallest item
print(" | > Synthesising audio from the model...")
inference_output = self.inference(
batch["text_input"][-1].unsqueeze(0), aux_input={"x_lenghts": batch["text_lengths"][-1].unsqueeze(0)}
)
figures["synthesised"] = plot_spectrogram(inference_output["model_outputs"][0], fig_size=(12, 3))
states = [p[1] for p in inference_output["input_parameters"][0]]
transition_probability_synthesising = [p[2].cpu().numpy() for p in inference_output["output_parameters"][0]]
for i in range((len(transition_probability_synthesising) // 200) + 1):
start = i * 200
end = (i + 1) * 200
figures[f"synthesised_transition_probabilities/{i}"] = plot_transition_probabilities_to_numpy(
states[start:end], transition_probability_synthesising[start:end]
)
audio = ap.inv_melspectrogram(inference_output["model_outputs"][0].T.cpu().numpy())
return figures, {"audios": audio}
def train_log(
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
): # pylint: disable=unused-argument
"""Log training progress."""
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.train_figures(steps, figures)
logger.train_audios(steps, audios, self.ap.sample_rate)
def eval_log(
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int
): # pylint: disable=unused-argument
"""Compute and log evaluation metrics."""
# Plot model parameters histograms
if isinstance(logger, TensorboardLogger):
# I don't know if any other loggers support this
for tag, value in self.named_parameters():
tag = tag.replace(".", "/")
logger.writer.add_histogram(tag, value.data.cpu().numpy(), steps)
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, self.ap.sample_rate)
def test_log(
self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument
) -> None:
logger.test_audios(steps, outputs[1], self.ap.sample_rate)
logger.test_figures(steps, outputs[0])
class NLLLoss(nn.Module):
"""Negative log likelihood loss."""
def forward(self, log_prob: torch.Tensor) -> dict: # pylint: disable=no-self-use
"""Compute the loss.
Args:
log_prob (Tensor): Log-likelihoods of the batch.
Returns:
Dict: `{"loss": -log_prob.mean()}`.
"""
return_dict = {}
return_dict["loss"] = -log_prob.mean()
return return_dict
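# A minimal sketch of how this criterion is used in train_step() above: the
# HMM log-likelihoods are normalised by the total number of mel and text
# frames before the mean NLL is taken. Tensor values are illustrative only.
import torch

criterion = NLLLoss()
log_probs = torch.randn(4)                       # stand-in for outputs["log_probs"]
mel_lengths = torch.tensor([120, 95, 140, 110])  # mel frames per batch item
text_lengths = torch.tensor([30, 25, 35, 28])    # tokens per batch item
loss_dict = criterion(log_probs / (mel_lengths.sum() + text_lengths.sum()))
print(loss_dict["loss"])                         # scalar negative log-likelihood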

View File

@ -111,9 +111,6 @@ class Overflow(BaseTTS):
self.register_buffer("mean", torch.tensor(0))
self.register_buffer("std", torch.tensor(1))
# self.mean = nn.Parameter(torch.zeros(1), requires_grad=False)
# self.std = nn.Parameter(torch.ones(1), requires_grad=False)
def update_mean_std(self, statistics_dict: Dict):
self.mean.data = torch.tensor(statistics_dict["mean"])
self.std.data = torch.tensor(statistics_dict["std"])

View File

@ -477,8 +477,8 @@ class VitsArgs(Coqpit):
use_d_vector_file (bool):
Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False.
d_vector_file (str):
Path to the file including pre-computed speaker embeddings. Defaults to None.
d_vector_file (List[str]):
List of paths to the files including pre-computed speaker embeddings. Defaults to None.
d_vector_dim (int):
Number of d-vector channels. Defaults to 0.
@ -573,7 +573,7 @@ class VitsArgs(Coqpit):
use_speaker_embedding: bool = False
num_speakers: int = 0
speakers_file: str = None
d_vector_file: str = None
d_vector_file: List[str] = None
speaker_embedding_channels: int = 256
use_d_vector_file: bool = False
d_vector_dim: int = 0

View File

@ -235,6 +235,9 @@ class EmbeddingManager(BaseIDManager):
self.embeddings_by_names.update(embeddings_by_names)
self.embeddings.update(embeddings)
# reset name_to_id to get the right speaker ids
self.name_to_id = {name: i for i, name in enumerate(self.name_to_id)}
def get_embedding_by_clip(self, clip_idx: str) -> List:
"""Get embedding by clip ID.
@ -321,7 +324,7 @@ class EmbeddingManager(BaseIDManager):
self.encoder_config = load_config(config_path)
self.encoder = setup_encoder_model(self.encoder_config)
self.encoder_criterion = self.encoder.load_checkpoint(
self.encoder_config, model_path, eval=True, use_cuda=use_cuda
self.encoder_config, model_path, eval=True, use_cuda=use_cuda, cache=True
)
self.encoder_ap = AudioProcessor(**self.encoder_config.audio)

View File

@ -109,10 +109,6 @@ class SpeakerManager(EmbeddingManager):
if get_from_config_or_model_args_with_default(config, "use_d_vector_file", False):
speaker_manager = SpeakerManager()
if get_from_config_or_model_args_with_default(config, "speakers_file", None):
speaker_manager = SpeakerManager(
d_vectors_file_path=get_from_config_or_model_args_with_default(config, "speaker_file", None)
)
if get_from_config_or_model_args_with_default(config, "d_vector_file", None):
speaker_manager = SpeakerManager(
d_vectors_file_path=get_from_config_or_model_args_with_default(config, "d_vector_file", None)

View File

@ -175,9 +175,15 @@ def synthesis(
style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
style_mel = style_mel.transpose(1, 2) # [1, time, depth]
language_name = None
if language_id is not None:
language = [k for k, v in model.language_manager.name_to_id.items() if v == language_id]
assert len(language) == 1, "language_id must be a valid language"
language_name = language[0]
# convert text to sequence of token IDs
text_inputs = np.asarray(
model.tokenizer.text_to_ids(text, language=language_id),
model.tokenizer.text_to_ids(text, language=language_name),
dtype=np.int32,
)
# pass tensors to backend

View File

@ -44,8 +44,25 @@ def remove_aux_symbols(text):
def replace_symbols(text, lang="en"):
"""Replace symbols based on the lenguage tag.
Args:
text:
Input text.
lang:
Language identifier, e.g. "en", "fr", "pt", "ca".
Returns:
The modified text
example:
input args:
text: "si l'avi cau, diguem-ho"
lang: "ca"
Output:
text: "si lavi cau, diguemho"
"""
text = text.replace(";", ",")
text = text.replace("-", " ")
text = text.replace("-", " ") if lang != "ca" else text.replace("-", "")
text = text.replace(":", ",")
if lang == "en":
text = text.replace("&", " and ")
@ -53,6 +70,9 @@ def replace_symbols(text, lang="en"):
text = text.replace("&", " et ")
elif lang == "pt":
text = text.replace("&", " e ")
elif lang == "ca":
text = text.replace("&", " i ")
text = text.replace("'", "")
return text
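# A quick sketch of the Catalan handling added above, mirroring the docstring
# example (the outputs in the comments follow directly from the rules above).
print(replace_symbols("si l'avi cau, diguem-ho", lang="ca"))  # "si lavi cau, diguemho"
print(replace_symbols("fish & chips", lang="en"))             # "&" is replaced by " and "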

View File

@ -114,7 +114,7 @@ class BasePhonemizer(abc.ABC):
return self._punctuator.restore(phonemized, punctuations)[0]
return phonemized[0]
def phonemize(self, text: str, separator="|") -> str:
def phonemize(self, text: str, separator="|", language: str = None) -> str: # pylint: disable=unused-argument
"""Returns the `text` phonemized for the given language
Args:

View File

@ -1,9 +1,10 @@
import logging
import re
import subprocess
from distutils.version import LooseVersion
from typing import Dict, List
from packaging.version import Version
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
from TTS.tts.utils.text.punctuation import Punctuation
@ -14,9 +15,16 @@ def is_tool(name):
return which(name) is not None
# Use a regex pattern to match the espeak version, because it may be
# symlinked to espeak-ng, which moves the version bits to another spot.
espeak_version_pattern = re.compile(r"text-to-speech:\s(?P<version>\d+\.\d+(\.\d+)?)")
def get_espeak_version():
output = subprocess.getoutput("espeak --version")
return output.split()[2]
match = espeak_version_pattern.search(output)
return match.group("version")
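# Illustrative check of the pattern above against typical `espeak --version`
# output when it is symlinked to espeak-ng (the sample string is made up):
# >>> sample = "eSpeak NG text-to-speech: 1.51  Data at: /usr/lib/x86_64-linux-gnu/espeak-ng-data"
# >>> espeak_version_pattern.search(sample).group("version")
# '1.51'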
def get_espeakng_version():
@ -168,7 +176,7 @@ class ESpeak(BasePhonemizer):
else:
# split with '_'
if self.backend == "espeak":
if LooseVersion(self.backend_version) >= LooseVersion("1.48.15"):
if Version(self.backend_version) >= Version("1.48.15"):
args.append("--ipa=1")
else:
args.append("--ipa=3")

View File

@ -43,7 +43,7 @@ class JA_JP_Phonemizer(BasePhonemizer):
return separator.join(ph)
return ph
def phonemize(self, text: str, separator="|") -> str:
def phonemize(self, text: str, separator="|", language=None) -> str:
"""Custom phonemize for JP_JA
Skip pre-post processing steps used by the other phonemizers.

View File

@ -40,7 +40,7 @@ class KO_KR_Phonemizer(BasePhonemizer):
return separator.join(ph)
return ph
def phonemize(self, text: str, separator: str = "", character: str = "hangeul") -> str:
def phonemize(self, text: str, separator: str = "", character: str = "hangeul", language=None) -> str:
return self._phonemize(text, separator, character)
@staticmethod

View File

@ -14,30 +14,40 @@ class MultiPhonemizer:
TODO: find a way to pass custom kwargs to the phonemizers
"""
lang_to_phonemizer_name = DEF_LANG_TO_PHONEMIZER
language = "multi-lingual"
lang_to_phonemizer = {}
def __init__(self, custom_lang_to_phonemizer: Dict = {}) -> None: # pylint: disable=dangerous-default-value
self.lang_to_phonemizer_name.update(custom_lang_to_phonemizer)
def __init__(self, lang_to_phonemizer_name: Dict = {}) -> None: # pylint: disable=dangerous-default-value
for k, v in lang_to_phonemizer_name.items():
if v == "" and k in DEF_LANG_TO_PHONEMIZER.keys():
lang_to_phonemizer_name[k] = DEF_LANG_TO_PHONEMIZER[k]
elif v == "":
raise ValueError(f"Phonemizer wasn't set for language {k} and doesn't have a default.")
self.lang_to_phonemizer_name = lang_to_phonemizer_name
self.lang_to_phonemizer = self.init_phonemizers(self.lang_to_phonemizer_name)
@staticmethod
def init_phonemizers(lang_to_phonemizer_name: Dict) -> Dict:
lang_to_phonemizer = {}
for k, v in lang_to_phonemizer_name.items():
phonemizer = get_phonemizer_by_name(v, language=k)
lang_to_phonemizer[k] = phonemizer
lang_to_phonemizer[k] = get_phonemizer_by_name(v, language=k)
return lang_to_phonemizer
@staticmethod
def name():
return "multi-phonemizer"
def phonemize(self, text, language, separator="|"):
def phonemize(self, text, separator="|", language=""):
if language == "":
raise ValueError("Language must be set for multi-phonemizer to phonemize.")
return self.lang_to_phonemizer[language].phonemize(text, separator)
def supported_languages(self) -> List:
return list(self.lang_to_phonemizer_name.keys())
return list(self.lang_to_phonemizer.keys())
def print_logs(self, level: int = 0):
indent = "\t" * level
print(f"{indent}| > phoneme language: {self.supported_languages()}")
print(f"{indent}| > phoneme backend: {self.name()}")
# if __name__ == "__main__":
@ -48,7 +58,7 @@ class MultiPhonemizer:
# "zh-cn": "这是中国的例子",
# }
# phonemes = {}
# ph = MultiPhonemizer()
# ph = MultiPhonemizer({"tr": "espeak", "en-us": "", "de": "gruut", "zh-cn": ""})
# for lang, text in texts.items():
# phoneme = ph.phonemize(text, lang)
# phonemes[lang] = phoneme
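# A short, hedged sketch of the new interface (requires the espeak, gruut and
# Chinese G2P backends to be installed); an empty string falls back to the
# default phonemizer for that language.
if __name__ == "__main__":
    ph = MultiPhonemizer({"tr": "espeak", "en-us": "", "de": "gruut", "zh-cn": ""})
    print(ph.supported_languages())  # ['tr', 'en-us', 'de', 'zh-cn']
    print(ph.phonemize("Be a voice, not an! echo?", separator="|", language="en-us"))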

View File

@ -3,6 +3,7 @@ from typing import Callable, Dict, List, Union
from TTS.tts.utils.text import cleaners
from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes
from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name
from TTS.tts.utils.text.phonemizers.multi_phonemizer import MultiPhonemizer
from TTS.utils.generic_utils import get_import_path, import_class
@ -106,7 +107,7 @@ class TTSTokenizer:
if self.text_cleaner is not None:
text = self.text_cleaner(text)
if self.use_phonemes:
text = self.phonemizer.phonemize(text, separator="")
text = self.phonemizer.phonemize(text, separator="", language=language)
if self.add_blank:
text = self.intersperse_blank_char(text, True)
if self.use_eos_bos:
@ -182,21 +183,29 @@ class TTSTokenizer:
# init phonemizer
phonemizer = None
if config.use_phonemes:
phonemizer_kwargs = {"language": config.phoneme_language}
if "phonemizer" in config and config.phonemizer:
phonemizer = get_phonemizer_by_name(config.phonemizer, **phonemizer_kwargs)
if "phonemizer" in config and config.phonemizer == "multi_phonemizer":
lang_to_phonemizer_name = {}
for dataset in config.datasets:
if dataset.language != "":
lang_to_phonemizer_name[dataset.language] = dataset.phonemizer
else:
raise ValueError("Multi phonemizer requires language to be set for each dataset.")
phonemizer = MultiPhonemizer(lang_to_phonemizer_name)
else:
try:
phonemizer = get_phonemizer_by_name(
DEF_LANG_TO_PHONEMIZER[config.phoneme_language], **phonemizer_kwargs
)
new_config.phonemizer = phonemizer.name()
except KeyError as e:
raise ValueError(
f"""No phonemizer found for language {config.phoneme_language}.
You may need to install a third party library for this language."""
) from e
phonemizer_kwargs = {"language": config.phoneme_language}
if "phonemizer" in config and config.phonemizer:
phonemizer = get_phonemizer_by_name(config.phonemizer, **phonemizer_kwargs)
else:
try:
phonemizer = get_phonemizer_by_name(
DEF_LANG_TO_PHONEMIZER[config.phoneme_language], **phonemizer_kwargs
)
new_config.phonemizer = phonemizer.name()
except KeyError as e:
raise ValueError(
f"""No phonemizer found for language {config.phoneme_language}.
You may need to install a third party library for this language."""
) from e
return (
TTSTokenizer(

View File

@ -123,6 +123,39 @@ def plot_avg_pitch(pitch, chars, fig_size=(30, 10), output_fig=False):
return fig
def plot_avg_energy(energy, chars, fig_size=(30, 10), output_fig=False):
"""Plot energy curves on top of the input characters.
Args:
energy (np.array): energy values.
chars (str): Characters to place to the x-axis.
Shapes:
energy: :math:`(T,)`
"""
old_fig_size = plt.rcParams["figure.figsize"]
if fig_size is not None:
plt.rcParams["figure.figsize"] = fig_size
fig, ax = plt.subplots()
x = np.array(range(len(chars)))
my_xticks = chars
plt.xticks(x, my_xticks)
ax.set_xlabel("characters")
ax.set_ylabel("freq")
ax2 = ax.twinx()
ax2.plot(energy, linewidth=5.0, color="red")
ax2.set_ylabel("energy")
plt.rcParams["figure.figsize"] = old_fig_size
if not output_fig:
plt.close()
return fig
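# A minimal usage sketch for the new helper; the energy values are made up for
# illustration and one value per character is expected.
import numpy as np

chars = "hello world"
energy = np.random.rand(len(chars))
fig = plot_avg_energy(energy, chars, fig_size=(10, 4), output_fig=True)
fig.savefig("avg_energy.png")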
def visualize(
alignment,
postnet_output,

View File

@ -4,7 +4,7 @@ import librosa
import numpy as np
import scipy
import soundfile as sf
from librosa import pyin
from librosa import magphase, pyin
# For using kwargs
# pylint: disable=unused-argument
@ -303,6 +303,27 @@ def compute_f0(
return f0
def compute_energy(y: np.ndarray, **kwargs) -> np.ndarray:
"""Compute energy of a waveform using the same parameters used for computing melspectrogram.
Args:
y (np.ndarray): Waveform. Shape :math:`[T_wav,]`
Returns:
np.ndarray: energy. Shape :math:`[T_energy,]`. :math:`T_energy == T_wav / hop_length`
Examples:
>>> WAV_FILE = filename = librosa.util.example_audio_file()
>>> from TTS.config import BaseAudioConfig
>>> from TTS.utils.audio import AudioProcessor
>>> conf = BaseAudioConfig()
>>> ap = AudioProcessor(**conf)
>>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate]
>>> energy = ap.compute_energy(wav)
"""
x = stft(y=y, **kwargs)
mag, _ = magphase(x)
energy = np.sqrt(np.sum(mag**2, axis=0))
return energy
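# An equivalent standalone sketch of the computation above using librosa
# directly (audio file and STFT parameters are illustrative): per-frame energy
# is the L2 norm of the magnitude-spectrogram columns.
import librosa
import numpy as np

wav, sr = librosa.load(librosa.ex("trumpet"), sr=22050)
mag, _ = librosa.magphase(librosa.stft(wav, n_fft=1024, hop_length=256, win_length=1024))
energy = np.sqrt(np.sum(mag**2, axis=0))
print(energy.shape)  # roughly len(wav) // 256 frames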
### Audio Processing ###
def find_endpoint(
*,

View File

@ -339,10 +339,18 @@ class ModelManager(object):
sub_conf = sub_conf[fd]
else:
return
sub_conf[field_names[-1]] = new_path
if isinstance(sub_conf[field_names[-1]], list):
sub_conf[field_names[-1]] = [new_path]
else:
sub_conf[field_names[-1]] = new_path
else:
# field name points to a top-level field
config[field_name] = new_path
if field_name not in config:
return
if isinstance(config[field_name], list):
config[field_name] = [new_path]
else:
config[field_name] = new_path
config.save_json(config_path)
@staticmethod

View File

@ -187,7 +187,7 @@ class Synthesizer(object):
text (str): input text.
speaker_name (str, optional): speaker id for multi-speaker models. Defaults to "".
language_name (str, optional): language id for multi-language models. Defaults to "".
speaker_wav (Union[str, List[str]], optional): path to the speaker wav. Defaults to None.
speaker_wav (Union[str, List[str]], optional): path to the speaker wav for voice cloning. Defaults to None.
style_wav ([type], optional): style waveform for GST. Defaults to None.
style_text ([type], optional): transcription of style_wav for Capacitron. Defaults to None.
reference_wav ([type], optional): reference waveform for voice conversion. Defaults to None.
@ -242,7 +242,7 @@ class Synthesizer(object):
"Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. "
)
# handle multi-lingaul
# handle multi-lingual
language_id = None
if self.tts_languages_file or (
hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None

View File

@ -113,7 +113,7 @@ def formatter(root_path, manifest_file, **kwargs): # pylint: disable=unused-arg
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0])
text = cols[1]
items.append({"text":text, "audio_file":wav_file, "speaker_name":speaker_name})
items.append({"text":text, "audio_file":wav_file, "speaker_name":speaker_name, "root_path": root_path})
return items
# load training samples

View File

@ -126,4 +126,13 @@ Here is an example for a single speaker model.
tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False, gpu=False)
# Run TTS
tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path=OUTPUT_PATH)
```
Example voice cloning with YourTTS in English, French and Portuguese:
```python
tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True)
tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="output.wav")
tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="output.wav")
tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="output.wav")
```

View File

@ -1,5 +1,5 @@
[build-system]
requires = ["setuptools", "wheel", "cython==0.29.28", "numpy==1.21.6"]
requires = ["setuptools", "wheel", "cython==0.29.28", "numpy==1.21.6", "packaging"]
[flake8]
max-line-length=120
@ -30,4 +30,4 @@ exclude = '''
[tool.isort]
line_length = 120
profile = "black"
multi_line_output = 3
multi_line_output = 3

View File

@ -0,0 +1,102 @@
import os
from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig
from TTS.tts.configs.fastspeech2_config import Fastspeech2Config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.forward_tts import ForwardTTS
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
from TTS.utils.manage import ModelManager
output_path = os.path.dirname(os.path.abspath(__file__))
# init configs
dataset_config = BaseDatasetConfig(
formatter="ljspeech",
meta_file_train="metadata.csv",
# meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"),
path=os.path.join(output_path, "../LJSpeech-1.1/"),
)
audio_config = BaseAudioConfig(
sample_rate=22050,
do_trim_silence=True,
trim_db=60.0,
signal_norm=False,
mel_fmin=0.0,
mel_fmax=8000,
spec_gain=1.0,
log_func="np.log",
ref_level_db=20,
preemphasis=0.0,
)
config = Fastspeech2Config(
run_name="fastspeech2_ljspeech",
audio=audio_config,
batch_size=32,
eval_batch_size=16,
num_loader_workers=8,
num_eval_loader_workers=4,
compute_input_seq_cache=True,
compute_f0=True,
f0_cache_path=os.path.join(output_path, "f0_cache"),
compute_energy=True,
energy_cache_path=os.path.join(output_path, "energy_cache"),
run_eval=True,
test_delay_epochs=-1,
epochs=1000,
text_cleaner="english_cleaners",
use_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
precompute_num_workers=4,
print_step=50,
print_eval=False,
mixed_precision=False,
max_seq_len=500000,
output_path=output_path,
datasets=[dataset_config],
)
# compute alignments
if not config.model_args.use_aligner:
manager = ModelManager()
model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA")
# TODO: make compute_attention python callable
os.system(
f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true"
)
# INITIALIZE THE AUDIO PROCESSOR
# Audio processor is used for feature extraction and audio I/O.
# It mainly serves to the dataloader and the training loggers.
ap = AudioProcessor.init_from_config(config)
# INITIALIZE THE TOKENIZER
# Tokenizer is used to convert text to sequences of token IDs.
# If characters are not defined in the config, default characters are passed to the config
tokenizer, config = TTSTokenizer.init_from_config(config)
# LOAD DATA SAMPLES
# Each sample is a list of ```[text, audio_file_path, speaker_name]```
# You can define your custom sample loader returning the list of samples.
# Or define your custom formatter and pass it to the `load_tts_samples`.
# Check `TTS.tts.datasets.load_tts_samples` for more details.
train_samples, eval_samples = load_tts_samples(
dataset_config,
eval_split=True,
eval_split_max_size=config.eval_split_max_size,
eval_split_size=config.eval_split_size,
)
# init the model
model = ForwardTTS(config, ap, tokenizer, speaker_manager=None)
# init the trainer and 🚀
trainer = Trainer(
TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
trainer.fit()

View File

@ -0,0 +1,96 @@
import os
from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.neuralhmm_tts_config import NeuralhmmTTSConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.neuralhmm_tts import NeuralhmmTTS
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
output_path = os.path.dirname(os.path.abspath(__file__))
# init configs
dataset_config = BaseDatasetConfig(
formatter="ljspeech", meta_file_train="metadata.csv", path=os.path.join("data", "LJSpeech-1.1/")
)
audio_config = BaseAudioConfig(
sample_rate=22050,
do_trim_silence=True,
trim_db=60.0,
signal_norm=False,
mel_fmin=0.0,
mel_fmax=8000,
spec_gain=1.0,
log_func="np.log",
ref_level_db=20,
preemphasis=0.0,
)
config = NeuralhmmTTSConfig( # This is the config that is saved for the future use
run_name="neuralhmmtts_ljspeech",
audio=audio_config,
batch_size=32,
eval_batch_size=16,
num_loader_workers=4,
num_eval_loader_workers=4,
run_eval=True,
test_delay_epochs=-1,
epochs=1000,
text_cleaner="phoneme_cleaners",
use_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
precompute_num_workers=8,
mel_statistics_parameter_path=os.path.join(output_path, "lj_parameters.pt"),
force_generate_statistics=False,
print_step=1,
print_eval=True,
mixed_precision=True,
output_path=output_path,
datasets=[dataset_config],
)
# INITIALIZE THE AUDIO PROCESSOR
# Audio processor is used for feature extraction and audio I/O.
# It mainly serves to the dataloader and the training loggers.
ap = AudioProcessor.init_from_config(config)
# INITIALIZE THE TOKENIZER
# Tokenizer is used to convert text to sequences of token IDs.
# If characters are not defined in the config, default characters are passed to the config
tokenizer, config = TTSTokenizer.init_from_config(config)
# LOAD DATA SAMPLES
# Each sample is a list of ```[text, audio_file_path, speaker_name]```
# You can define your custom sample loader returning the list of samples.
# Or define your custom formatter and pass it to the `load_tts_samples`.
# Check `TTS.tts.datasets.load_tts_samples` for more details.
train_samples, eval_samples = load_tts_samples(
dataset_config,
eval_split=True,
eval_split_max_size=config.eval_split_max_size,
eval_split_size=config.eval_split_size,
)
# INITIALIZE THE MODEL
# Models take a config object and a speaker manager as input
# Config defines the details of the model like the number of layers, the size of the embedding, etc.
# Speaker manager is used by multi-speaker models.
model = NeuralhmmTTS(config, ap, tokenizer)
# init the trainer and 🚀
trainer = Trainer(
TrainerArgs(),
config,
output_path,
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
gpu=1,
)
trainer.fit()

View File

@ -0,0 +1,126 @@
import os
from glob import glob
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits, VitsArgs, VitsAudioConfig
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
output_path = "/media/julian/Workdisk/train"
mailabs_path = "/home/julian/workspace/mailabs/**"
dataset_paths = glob(mailabs_path)
dataset_config = [
BaseDatasetConfig(
formatter="mailabs",
meta_file_train=None,
path=path,
language=path.split("/")[-1], # language code is the folder name
)
for path in dataset_paths
]
audio_config = VitsAudioConfig(
sample_rate=16000,
win_length=1024,
hop_length=256,
num_mels=80,
mel_fmin=0,
mel_fmax=None,
)
vitsArgs = VitsArgs(
use_language_embedding=True,
embedded_language_dim=4,
use_speaker_embedding=True,
use_sdp=False,
)
config = VitsConfig(
model_args=vitsArgs,
audio=audio_config,
run_name="vits_vctk",
use_speaker_embedding=True,
batch_size=32,
eval_batch_size=16,
batch_group_size=0,
num_loader_workers=12,
num_eval_loader_workers=12,
precompute_num_workers=12,
run_eval=True,
test_delay_epochs=-1,
epochs=1000,
text_cleaner="multilingual_cleaners",
use_phonemes=True,
phoneme_language=None,
phonemizer="multi_phonemizer",
phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
compute_input_seq_cache=True,
print_step=25,
use_language_weighted_sampler=True,
print_eval=False,
mixed_precision=False,
min_audio_len=audio_config.sample_rate,
max_audio_len=audio_config.sample_rate * 10,
output_path=output_path,
datasets=dataset_config,
test_sentences=[
[
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"mary_ann",
None,
"en-us",
],
[
"Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.",
"ezwa",
None,
"fr-fr",
],
["Ich finde, dieses Startup ist wirklich unglaublich.", "eva_k", None, "de-de"],
["Я думаю, что этот стартап действительно удивительный.", "nikolaev", None, "ru"],
],
)
# force the conversion of the custom characters to a config attribute
config.from_dict(config.to_dict())
# init audio processor
ap = AudioProcessor(**config.audio.to_dict())
# load training samples
train_samples, eval_samples = load_tts_samples(
dataset_config,
eval_split=True,
eval_split_max_size=config.eval_split_max_size,
eval_split_size=config.eval_split_size,
)
# init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.model_args.num_speakers = speaker_manager.num_speakers
language_manager = LanguageManager(config=config)
config.model_args.num_languages = language_manager.num_languages
# INITIALIZE THE TOKENIZER
# Tokenizer is used to convert text to sequences of token IDs.
# config is updated with the default characters if not defined in the config.
tokenizer, config = TTSTokenizer.init_from_config(config)
# init model
model = Vits(config, ap, tokenizer, speaker_manager, language_manager)
# init the trainer and 🚀
trainer = Trainer(
TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
trainer.fit()

View File

@ -57,7 +57,25 @@ if not os.path.exists(VCTK_DOWNLOAD_PATH):
# init configs
vctk_config = BaseDatasetConfig(
formatter="vctk", dataset_name="vctk", meta_file_train="", meta_file_val="", path=VCTK_DOWNLOAD_PATH, language="en"
formatter="vctk",
dataset_name="vctk",
meta_file_train="",
meta_file_val="",
path=VCTK_DOWNLOAD_PATH,
language="en",
ignored_speakers=[
"p261",
"p225",
"p294",
"p347",
"p238",
"p234",
"p248",
"p335",
"p245",
"p326",
"p302",
], # Ignore the test speakers to fully replicate the paper's experiment
)
# Add all dataset configs here. In our case we only want to train with the VCTK dataset, so we add just VCTK. Note: if you want to add new datasets, just add them here and the speaker embeddings (d-vectors) will be computed automatically for the new datasets :)
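# e.g., a hypothetical second dataset added next to VCTK (path and names are
# placeholders; its speaker embeddings would be computed automatically):
# libritts_config = BaseDatasetConfig(
#     formatter="libri_tts", dataset_name="libri_tts", meta_file_train="", path="/path/to/LibriTTS", language="en"
# )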
@ -111,11 +129,11 @@ model_args = VitsArgs(
use_d_vector_file=True,
d_vector_dim=512,
num_layers_text_encoder=10,
speaker_encoder_model_path=SPEAKER_ENCODER_CHECKPOINT_PATH,
speaker_encoder_config_path=SPEAKER_ENCODER_CONFIG_PATH,
resblock_type_decoder="2", # On the paper, we accidentally trained the YourTTS using ResNet blocks type 2, if you like you can use the ResNet blocks type 1 like the VITS model
# Useful parameters to enable the Speaker Consistency Loss (SCL) described in the paper
# use_speaker_encoder_as_loss=True,
# speaker_encoder_model_path=SPEAKER_ENCODER_CHECKPOINT_PATH,
# speaker_encoder_config_path=SPEAKER_ENCODER_CONFIG_PATH,
# Useful parameters to enable multilingual training
# use_language_embedding=True,
# embedded_language_dim=4,
@ -207,6 +225,7 @@ config = VitsConfig(
use_weighted_sampler=True,
# Ensures that all speakers are seen in the training batch equally no matter how many samples each speaker has
weighted_sampler_attrs={"speaker_name": 1.0},
weighted_sampler_multipliers={},
# Set the Speaker Consistency Loss (SCL) α to 9, as in the paper
speaker_encoder_loss_alpha=9.0,
)

View File

@ -14,6 +14,7 @@ tqdm
anyascii
pyyaml
fsspec>=2021.04.0
packaging
# deps for examples
flask
# deps for inference
@ -24,7 +25,7 @@ pandas
# deps for training
matplotlib
# coqui stack
trainer
trainer==0.0.20
# config management
coqpit>=0.0.16
# chinese g2p deps

View File

@ -23,7 +23,7 @@
import os
import subprocess
import sys
from distutils.version import LooseVersion
from packaging.version import Version
import numpy
import setuptools.command.build_py
@ -31,7 +31,8 @@ import setuptools.command.develop
from Cython.Build import cythonize
from setuptools import Extension, find_packages, setup
if LooseVersion(sys.version) < LooseVersion("3.7") or LooseVersion(sys.version) >= LooseVersion("3.11"):
python_version = sys.version.split()[0]
if Version(python_version) < Version("3.7") or Version(python_version) >= Version("3.11"):
raise RuntimeError("TTS requires python >= 3.7 and < 3.11 " "but your Python version is {}".format(sys.version))

View File

@ -1,10 +1,12 @@
import os
import unittest
from tests import get_tests_output_path
from tests import get_tests_data_path, get_tests_output_path
from TTS.api import TTS
OUTPUT_PATH = os.path.join(get_tests_output_path(), "test_python_api.wav")
cloning_test_wav_path = os.path.join(get_tests_data_path(), "ljspeech/wavs/LJ001-0028.wav")
class TTSTest(unittest.TestCase):
@ -34,3 +36,9 @@ class TTSTest(unittest.TestCase):
self.assertTrue(tts.is_multi_lingual)
self.assertGreater(len(tts.speakers), 1)
self.assertGreater(len(tts.languages), 1)
def test_voice_cloning(self):
tts = TTS()
tts.load_model_by_name("tts_models/multilingual/multi-dataset/your_tts")
tts.tts_to_file("Hello world!", speaker_wav=cloning_test_wav_path, language="en", file_path=OUTPUT_PATH)

View File

@ -1,7 +1,9 @@
import unittest
from distutils.version import LooseVersion
from packaging.version import Version
from TTS.tts.utils.text.phonemizers import ESpeak, Gruut, JA_JP_Phonemizer, ZH_CN_Phonemizer
from TTS.tts.utils.text.phonemizers.multi_phonemizer import MultiPhonemizer
EXAMPLE_TEXTs = [
"Recent research at Harvard has shown meditating",
@ -39,7 +41,7 @@ class TestEspeakPhonemizer(unittest.TestCase):
def setUp(self):
self.phonemizer = ESpeak(language="en-us", backend="espeak")
if LooseVersion(self.phonemizer.backend_version) >= LooseVersion("1.48.15"):
if Version(self.phonemizer.backend_version) >= Version("1.48.15"):
target_phonemes = EXPECTED_ESPEAK_v1_48_15_PHONEMES
else:
target_phonemes = EXPECTED_ESPEAK_PHONEMES
@ -51,7 +53,7 @@ class TestEspeakPhonemizer(unittest.TestCase):
# multiple punctuations
text = "Be a voice, not an! echo?"
gt = "biː ɐ vˈɔɪs, nˈɑːt ɐn! ˈɛkoʊ?"
if LooseVersion(self.phonemizer.backend_version) >= LooseVersion("1.48.15"):
if Version(self.phonemizer.backend_version) >= Version("1.48.15"):
gt = "biː ɐ vˈɔɪs, nˈɑːt æn! ˈɛkoʊ?"
output = self.phonemizer.phonemize(text, separator="|")
output = output.replace("|", "")
@ -60,7 +62,7 @@ class TestEspeakPhonemizer(unittest.TestCase):
# not ending with punctuation
text = "Be a voice, not an! echo"
gt = "biː ɐ vˈɔɪs, nˈɑːt ɐn! ˈɛkoʊ"
if LooseVersion(self.phonemizer.backend_version) >= LooseVersion("1.48.15"):
if Version(self.phonemizer.backend_version) >= Version("1.48.15"):
gt = "biː ɐ vˈɔɪs, nˈɑːt æn! ˈɛkoʊ"
output = self.phonemizer.phonemize(text, separator="")
self.assertEqual(output, gt)
@ -68,7 +70,7 @@ class TestEspeakPhonemizer(unittest.TestCase):
# extra space after the sentence
text = "Be a voice, not an! echo. "
gt = "biː ɐ vˈɔɪs, nˈɑːt ɐn! ˈɛkoʊ."
if LooseVersion(self.phonemizer.backend_version) >= LooseVersion("1.48.15"):
if Version(self.phonemizer.backend_version) >= Version("1.48.15"):
gt = "biː ɐ vˈɔɪs, nˈɑːt æn! ˈɛkoʊ."
output = self.phonemizer.phonemize(text, separator="")
self.assertEqual(output, gt)
@ -226,3 +228,46 @@ class TestZH_CN_Phonemizer(unittest.TestCase):
def test_is_available(self):
self.assertTrue(self.phonemizer.is_available())
class TestMultiPhonemizer(unittest.TestCase):
def setUp(self):
self.phonemizer = MultiPhonemizer({"tr": "espeak", "en-us": "", "de": "gruut", "zh-cn": ""})
def test_phonemize(self):
# English espeak
text = "Be a voice, not an! echo?"
gt = "biː ɐ vˈɔɪs, nˈɑːt æn! ˈɛkoʊ?"
output = self.phonemizer.phonemize(text, separator="|", language="en-us")
output = output.replace("|", "")
self.assertEqual(output, gt)
# German gruut
text = "Hallo, das ist ein Deutches Beipiel!"
gt = "haloː, das ɪst aeːn dɔɔʏ̯tçəs bəʔiːpiːl!"
output = self.phonemizer.phonemize(text, separator="|", language="de")
output = output.replace("|", "")
self.assertEqual(output, gt)
def test_phonemizer_initialization(self):
# test with unsupported language
with self.assertRaises(ValueError):
MultiPhonemizer({"tr": "espeak", "xx": ""})
# test with unsupported phonemizer
with self.assertRaises(ValueError):
MultiPhonemizer({"tr": "espeak", "fr": "xx"})
def test_sub_phonemizers(self):
for lang in self.phonemizer.lang_to_phonemizer_name.keys():
self.assertEqual(lang, self.phonemizer.lang_to_phonemizer[lang].language)
self.assertEqual(
self.phonemizer.lang_to_phonemizer_name[lang], self.phonemizer.lang_to_phonemizer[lang].name()
)
def test_name(self):
self.assertEqual(self.phonemizer.name(), "multi-phonemizer")
def test_get_supported_languages(self):
self.assertIsInstance(self.phonemizer.supported_languages(), list)

View File

@ -0,0 +1,95 @@
import glob
import json
import os
import shutil
from trainer import get_last_checkpoint
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.fastspeech2_config import Fastspeech2Config
config_path = os.path.join(get_tests_output_path(), "fast_pitch_speaker_emb_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
audio_config = BaseAudioConfig(
sample_rate=22050,
do_trim_silence=True,
trim_db=60.0,
signal_norm=False,
mel_fmin=0.0,
mel_fmax=8000,
spec_gain=1.0,
log_func="np.log",
ref_level_db=20,
preemphasis=0.0,
)
config = Fastspeech2Config(
audio=audio_config,
batch_size=8,
eval_batch_size=8,
num_loader_workers=0,
num_eval_loader_workers=0,
text_cleaner="english_cleaners",
use_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
f0_cache_path="tests/data/ljspeech/f0_cache/",
compute_f0=True,
compute_energy=True,
energy_cache_path="tests/data/ljspeech/f0_cache/",
run_eval=True,
test_delay_epochs=-1,
epochs=1,
print_step=1,
print_eval=True,
use_speaker_embedding=True,
test_sentences=[
"Be a voice, not an echo.",
],
)
config.audio.do_trim_silence = True
config.use_speaker_embedding = True
config.model_args.use_speaker_embedding = True
config.audio.trim_db = 60
config.save_json(config_path)
# train the model for one epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
f"--coqpit.output_path {output_path} "
"--coqpit.datasets.0.formatter ljspeech_test "
"--coqpit.datasets.0.meta_file_train metadata.csv "
"--coqpit.datasets.0.meta_file_val metadata.csv "
"--coqpit.datasets.0.path tests/data/ljspeech "
"--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
"--coqpit.test_delay_epochs 0"
)
run_cli(command_train)
# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# Inference using TTS API
continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
speaker_id = "ljspeech-1"
continue_speakers_path = os.path.join(continue_path, "speakers.json")
# Check integrity of the config
with open(continue_config_path, "r", encoding="utf-8") as f:
config_loaded = json.load(f)
assert config_loaded["characters"] is not None
assert config_loaded["output_path"] in continue_path
assert config_loaded["test_delay_epochs"] == 0
# Load the model and run inference
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
run_cli(inference_command)
# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)

View File

@ -0,0 +1,94 @@
import glob
import json
import os
import shutil
from trainer import get_last_checkpoint
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.fastspeech2_config import Fastspeech2Config
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
audio_config = BaseAudioConfig(
sample_rate=22050,
do_trim_silence=True,
trim_db=60.0,
signal_norm=False,
mel_fmin=0.0,
mel_fmax=8000,
spec_gain=1.0,
log_func="np.log",
ref_level_db=20,
preemphasis=0.0,
)
config = Fastspeech2Config(
audio=audio_config,
batch_size=8,
eval_batch_size=8,
num_loader_workers=0,
num_eval_loader_workers=0,
text_cleaner="english_cleaners",
use_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
f0_cache_path="tests/data/ljspeech/f0_cache/",
compute_f0=True,
compute_energy=True,
energy_cache_path="tests/data/ljspeech/f0_cache/",
run_eval=True,
test_delay_epochs=-1,
epochs=1,
print_step=1,
print_eval=True,
test_sentences=[
"Be a voice, not an echo.",
],
use_speaker_embedding=False,
)
config.audio.do_trim_silence = True
config.use_speaker_embedding = False
config.model_args.use_speaker_embedding = False
config.audio.trim_db = 60
config.save_json(config_path)
# train the model for one epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
f"--coqpit.output_path {output_path} "
"--coqpit.datasets.0.formatter ljspeech "
"--coqpit.datasets.0.meta_file_train metadata.csv "
"--coqpit.datasets.0.meta_file_val metadata.csv "
"--coqpit.datasets.0.path tests/data/ljspeech "
"--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
"--coqpit.test_delay_epochs 0"
)
run_cli(command_train)
# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# Inference using TTS API
continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
# Check integrity of the config
with open(continue_config_path, "r", encoding="utf-8") as f:
config_loaded = json.load(f)
assert config_loaded["characters"] is not None
assert config_loaded["output_path"] in continue_path
assert config_loaded["test_delay_epochs"] == 0
# Load the model and run inference
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
run_cli(inference_command)
# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)

View File

@ -0,0 +1,92 @@
import glob
import json
import os
import shutil
import torch
from trainer import get_last_checkpoint
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.neuralhmm_tts_config import NeuralhmmTTSConfig
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
parameter_path = os.path.join(get_tests_output_path(), "lj_parameters.pt")
torch.save({"mean": -5.5138, "std": 2.0636, "init_transition_prob": 0.3212}, parameter_path)
config = NeuralhmmTTSConfig(
batch_size=3,
eval_batch_size=3,
num_loader_workers=0,
num_eval_loader_workers=0,
text_cleaner="phoneme_cleaners",
use_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path=os.path.join(get_tests_output_path(), "train_outputs/phoneme_cache/"),
run_eval=True,
test_delay_epochs=-1,
mel_statistics_parameter_path=parameter_path,
epochs=1,
print_step=1,
test_sentences=[
"Be a voice, not an echo.",
],
print_eval=True,
max_sampling_time=50,
)
config.audio.do_trim_silence = True
config.audio.trim_db = 60
config.save_json(config_path)
# train the model for one epoch when mel parameters exists
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
f"--coqpit.output_path {output_path} "
"--coqpit.datasets.0.formatter ljspeech "
"--coqpit.datasets.0.meta_file_train metadata.csv "
"--coqpit.datasets.0.meta_file_val metadata.csv "
"--coqpit.datasets.0.path tests/data/ljspeech "
"--coqpit.test_delay_epochs 0 "
)
run_cli(command_train)
# train the model for one epoch when mel parameters have to be computed from the dataset
if os.path.exists(parameter_path):
os.remove(parameter_path)
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
f"--coqpit.output_path {output_path} "
"--coqpit.datasets.0.formatter ljspeech "
"--coqpit.datasets.0.meta_file_train metadata.csv "
"--coqpit.datasets.0.meta_file_val metadata.csv "
"--coqpit.datasets.0.path tests/data/ljspeech "
"--coqpit.test_delay_epochs 0 "
)
run_cli(command_train)
# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# Inference using TTS API
continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
# Check integrity of the config
with open(continue_config_path, "r", encoding="utf-8") as f:
config_loaded = json.load(f)
assert config_loaded["characters"] is not None
assert config_loaded["output_path"] in continue_path
assert config_loaded["test_delay_epochs"] == 0
# Load the model and run inference
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
run_cli(inference_command)
# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)

View File

@ -210,7 +210,7 @@ class TestVits(unittest.TestCase):
num_chars=32,
use_d_vector_file=True,
d_vector_dim=256,
d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")],
)
config = VitsConfig(model_args=args)
model = Vits.init_from_config(config, verbose=False).to(device)
@ -355,7 +355,7 @@ class TestVits(unittest.TestCase):
num_chars=32,
use_d_vector_file=True,
d_vector_dim=256,
d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")],
)
config = VitsConfig(model_args=args)
model = Vits.init_from_config(config, verbose=False).to(device)
@ -587,7 +587,7 @@ class TestVits(unittest.TestCase):
num_chars=32,
use_d_vector_file=True,
d_vector_dim=256,
d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")],
)
)
model = Vits.init_from_config(config, verbose=False).to(device)

View File

@ -33,7 +33,7 @@ config.audio.trim_db = 60
# active multispeaker d-vec mode
config.model_args.use_d_vector_file = True
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
config.model_args.d_vector_file = ["tests/data/ljspeech/speakers.json"]
config.model_args.d_vector_dim = 256

View File

@ -63,8 +63,8 @@ config.use_speaker_embedding = False
# active multispeaker d-vec mode
config.model_args.use_d_vector_file = True
config.use_d_vector_file = True
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
config.d_vector_file = "tests/data/ljspeech/speakers.json"
config.model_args.d_vector_file = ["tests/data/ljspeech/speakers.json"]
config.d_vector_file = ["tests/data/ljspeech/speakers.json"]
config.model_args.d_vector_dim = 256
config.d_vector_dim = 256