mirror of https://github.com/coqui-ai/TTS.git
implement DeepSpeechConfig
This commit is contained in:
parent
e033b3d7bb
commit
5ee01e5624
@ -0,0 +1,17 @@
import importlib
import os
from inspect import isclass

# import all files under configs/
configs_dir = os.path.dirname(__file__)
for file in os.listdir(configs_dir):
    path = os.path.join(configs_dir, file)
    if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)):
        config_name = file[: file.find(".py")] if file.endswith(".py") else file
        module = importlib.import_module("TTS.stt.configs." + config_name)
        for attribute_name in dir(module):
            attribute = getattr(module, attribute_name)

            if isclass(attribute):
                # Add the class to this package's variables
                globals()[attribute_name] = attribute
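With this loop in the package `__init__`, every class defined in a module under `configs/` is re-exported from the package root. A minimal sketch of the intended effect, assuming the `DeepSpeechConfig` added below is on the import path:

    >>> import TTS.stt.configs as stt_configs
    >>> hasattr(stt_configs, "DeepSpeechConfig")  # re-exported by the loop above
    True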
@ -0,0 +1,112 @@
from dataclasses import dataclass, field
from typing import Dict

from TTS.stt.configs.shared_configs import BaseSTTConfig
from TTS.stt.models.deep_speech import DeepSpeechArgs


@dataclass
class DeepSpeechConfig(BaseSTTConfig):
    """Defines parameters for the DeepSpeech STT model.

    Args:
        model (str):
            Model name. Do not change unless you know what you are doing.

        model_args (DeepSpeechArgs):
            Model architecture arguments. Defaults to `DeepSpeechArgs()`.

        grad_clip (float):
            Gradient clipping threshold. Defaults to `10`.

        lr (float):
            Initial learning rate. Defaults to `0.0001`.

        lr_scheduler (str):
            Name of the learning rate scheduler. One of the `torch.optim.lr_scheduler.*`. Defaults to
            `ExponentialLR`.

        lr_scheduler_params (dict):
            Parameters for the learning rate scheduler. Defaults to `{"gamma": 0.999875, "last_epoch": -1}`.

        scheduler_after_epoch (bool):
            If true, step the scheduler after each epoch, else after each step. Defaults to `True`.

        optimizer (str):
            Name of the optimizer. One of the `torch.optim.*`. Defaults to `AdamW`.

        optimizer_params (dict):
            Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01}`.

        loss_masking (bool):
            If true, mask loss values against the padded segments of samples in a batch. Defaults to `True`.

        feature_extractor (str):
            Name of the feature extractor used by the model. Defaults to `MFCC`.

        sort_by_audio_len (bool):
            If true, the dataloader sorts the data by audio length, else it sorts by the input text length. Defaults to `True`.

        min_seq_len (int):
            Minimum sequence length to be considered for training. Defaults to `0`.

        max_seq_len (int):
            Maximum sequence length to be considered for training. Defaults to `500000`.

    Note:
        Check :class:`TTS.stt.configs.shared_configs.BaseSTTConfig` for the inherited parameters.

    Example:

        >>> from TTS.stt.configs import DeepSpeechConfig
        >>> config = DeepSpeechConfig()
    """

    model: str = "deep_speech"
    # model specific params
    model_args: DeepSpeechArgs = field(default_factory=DeepSpeechArgs)

    # optimizer
    grad_clip: float = 10
    lr: float = 0.0001
    lr_scheduler: str = "ExponentialLR"
    lr_scheduler_params: Dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
    scheduler_after_epoch: bool = True
    optimizer: str = "AdamW"
    optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01})

    # overrides
    loss_masking: bool = True
    feature_extractor: str = "MFCC"
    sort_by_audio_len: bool = True
    min_seq_len: int = 0
    max_seq_len: int = 500000
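Because the config is a plain dataclass, fields can be overridden at construction time. A usage sketch (the values are illustrative, not recommendations):

    >>> from TTS.stt.configs import DeepSpeechConfig
    >>> config = DeepSpeechConfig(lr=0.0002, max_seq_len=250000)
    >>> config.optimizer
    'AdamW'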
@ -0,0 +1,90 @@
from dataclasses import dataclass, field
from typing import Dict, List

from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig


@dataclass
class BaseSTTConfig(BaseTrainingConfig):
    """Shared parameters among all the STT models.

    Args:
        model_type (str):
            The type of the model that defines the trainer mode. Do not change this value.

        audio (BaseAudioConfig):
            Audio processor config object instance.

        enable_eos_bos_chars (bool):
            Enable / disable the use of eos and bos characters.

        vocabulary (Dict[str, int]):
            Vocabulary used by the model.

        batch_group_size (int):
            Size of the batch groups used for bucketing. By default, the dataloader orders samples by the sequence
            length for a more efficient and stable training. If `batch_group_size > 1` then it performs bucketing to
            prevent using the same batches for each epoch.

        loss_masking (bool):
            Enable / disable masking loss values against padded segments of samples in a batch.

        feature_extractor (str):
            Name of the feature extractor used by the model, one of the values supported by
            `TTS.stt.datasets.dataset.FeatureExtractor`. Defaults to `None`.

        sort_by_audio_len (bool):
            If true, the dataloader sorts the data by audio length, else it sorts by the input text length. Defaults to `True`.

        min_seq_len (int):
            Minimum sequence length to be used at training.

        max_seq_len (int):
            Maximum sequence length to be used at training. Larger values result in more VRAM usage.

        train_datasets (List[BaseDatasetConfig]):
            List of datasets used for training. If multiple datasets are provided, they are merged and used together
            for training.

        eval_datasets (List[BaseDatasetConfig]):
            List of datasets used for evaluation. If multiple datasets are provided, they are merged and used together
            for evaluation.

        optimizer (str):
            Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`.
            Defaults to `None`.

        optimizer_params (dict):
            Optimizer kwargs. Defaults to `None`.

        lr_scheduler (str):
            Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or
            `TTS.utils.training`. Defaults to `""`.

        lr_scheduler_params (dict):
            Parameters for the learning rate scheduler. Defaults to `None`.
    """

    model_type: str = "stt"
    audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
    # phoneme settings
    enable_eos_bos_chars: bool = False
    # vocabulary parameters
    vocabulary: dict = None
    # training params
    batch_group_size: int = 0
    loss_masking: bool = None
    # dataloading
    feature_extractor: str = None
    sort_by_audio_len: bool = True
    min_seq_len: int = 1
    max_seq_len: int = float("inf")
    # dataset
    train_datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
    eval_datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
    # optimizer
    optimizer: str = None
    optimizer_params: dict = None
    # scheduler
    lr_scheduler: str = ""
    lr_scheduler_params: dict = None
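Since the dataset lists are nested `BaseDatasetConfig` objects, a concrete config wires its data in directly. A hypothetical sketch, assuming `BaseDatasetConfig` exposes the same `name`/`path`/meta-file fields consumed as dict keys by `load_stt_samples` below (the LJSpeech values are placeholders):

    >>> from TTS.config import BaseDatasetConfig
    >>> from TTS.stt.configs import DeepSpeechConfig
    >>> dataset = BaseDatasetConfig(
    ...     name="ljspeech", path="/data/LJSpeech-1.1", meta_file_train="metadata.csv", meta_file_val=None
    ... )
    >>> config = DeepSpeechConfig(train_datasets=[dataset], eval_datasets=[dataset])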
@ -0,0 +1,62 @@
import sys
from typing import Dict, List, Tuple, Union

import numpy as np

from TTS.stt.datasets.formatters import *


def split_dataset(items, eval_split_size=None):
    if not eval_split_size:
        eval_split_size = min(500, int(len(items) * 0.01))
    assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples."
    np.random.seed(0)
    np.random.shuffle(items)
    return items[:eval_split_size], items[eval_split_size:]


def load_stt_samples(
    datasets: Union[Dict, List[Dict]], eval_split=True, eval_split_size: int = None
) -> Tuple[List[List], List[List]]:
    """Parse the dataset and load the samples as a list.

    Args:
        datasets (Union[Dict, List[Dict]]): A list of dataset configs or a single dataset config.
        eval_split (bool, optional): If true, create an evaluation split. If no eval split is provided explicitly,
            one is generated automatically. Defaults to True.
        eval_split_size (int, optional): The size of the evaluation split. If None, use the pre-defined split size.
            Defaults to None.

    Returns:
        Tuple[List[List], List[List]]: training and evaluation splits of the dataset.
    """
    if not isinstance(datasets, list):
        datasets = [datasets]
    meta_data_train_all = []
    meta_data_eval_all = [] if eval_split else None
    for dataset in datasets:
        name = dataset["name"]
        root_path = dataset["path"]
        meta_file_train = dataset["meta_file_train"]
        meta_file_val = dataset["meta_file_val"]
        # setup the right data processor
        preprocessor = _get_preprocessor_by_name(name)
        # load train set
        meta_data_train = preprocessor(root_path, meta_file_train)
        # load evaluation split if set
        if eval_split:
            if meta_file_val:
                meta_data_eval = preprocessor(root_path, meta_file_val)
            else:
                meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_size)
            meta_data_eval_all += meta_data_eval
        meta_data_train_all += meta_data_train
    return meta_data_train_all, meta_data_eval_all


def _get_preprocessor_by_name(name):
    """Returns the respective preprocessing function."""
    thismodule = sys.modules[__name__]
    return getattr(thismodule, name.lower())
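A usage sketch of the loader: each dataset entry only needs the four keys read in the loop above, and the eval split falls back to `split_dataset` when no `meta_file_val` is given. The import path and the `ljspeech` formatter name are assumptions; the paths are placeholders:

    >>> from TTS.stt.datasets import load_stt_samples  # import path is an assumption
    >>> dataset = {"name": "ljspeech", "path": "/data/LJSpeech-1.1",
    ...            "meta_file_train": "metadata.csv", "meta_file_val": None}
    >>> train_samples, eval_samples = load_stt_samples(dataset, eval_split=True)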
@ -0,0 +1,14 @@
from TTS.utils.generic_utils import find_module


def setup_stt_model(config):
    print(" > Using model: {}".format(config.model))
    # fetch the right model implementation.
    if "base_model" in config and config["base_model"] is not None:
        MyModel = find_module("TTS.stt.models", config.base_model.lower())
    else:
        MyModel = find_module("TTS.stt.models", config.model.lower())
    model = MyModel(config)
    return model
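Putting the pieces together, a sketch of how a model would be instantiated from the config, assuming a DeepSpeech model class is registered under `TTS.stt.models`:

    >>> from TTS.stt.configs import DeepSpeechConfig
    >>> config = DeepSpeechConfig()
    >>> model = setup_stt_model(config)  # resolves the "deep_speech" module via find_module
     > Using model: deep_speech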