implement DeepSpeechConfig

This commit is contained in:
Eren Gölge 2021-09-30 14:43:26 +00:00
parent e033b3d7bb
commit 5ee01e5624
6 changed files with 295 additions and 0 deletions

0
TTS/stt/__init__.py Normal file
View File

View File

@ -0,0 +1,17 @@
import importlib
import os
from inspect import isclass
# Import every config module/sub-package that lives next to this file and
# re-export all classes they define, so that callers can simply write
# `from TTS.stt.configs import SomeConfig`.
configs_dir = os.path.dirname(__file__)
# sorted() makes the import order (and therefore which class wins when two
# modules expose the same class name) independent of the filesystem's listing order.
for file in sorted(os.listdir(configs_dir)):
    # skip private/dunder entries (e.g. `__init__.py`, `__pycache__`) and hidden files
    if file.startswith(("_", ".")):
        continue
    path = os.path.join(configs_dir, file)
    if file.endswith(".py"):
        # strip only the trailing ".py" suffix; searching for the first ".py"
        # occurrence would truncate names that contain ".py" elsewhere
        config_name = file[: -len(".py")]
    elif os.path.isdir(path):
        config_name = file
    else:
        continue
    module = importlib.import_module("TTS.stt.configs." + config_name)
    for attribute_name in dir(module):
        attribute = getattr(module, attribute_name)
        if isclass(attribute):
            # Add the class to this package's variables
            globals()[attribute_name] = attribute

View File

@ -0,0 +1,112 @@
from dataclasses import dataclass, field
from typing import Dict, List
from TTS.stt.configs.shared_configs import BaseSTTConfig
from TTS.stt.models.deep_speech import DeepSpeechArgs
@dataclass
class DeepSpeechConfig(BaseSTTConfig):
    """Defines parameters for the DeepSpeech speech-to-text (STT) model.

    Args:
        model (str):
            Model name. Do not change unless you know what you are doing. Defaults to `deep_speech`.

        model_args (DeepSpeechArgs):
            Model architecture arguments. Defaults to `DeepSpeechArgs()`.

        grad_clip (float):
            Gradient clipping threshold. Defaults to `10`.

        lr (float):
            Initial learning rate. Defaults to `0.0001`.

        lr_scheduler (str):
            Name of the learning rate scheduler. One of the `torch.optim.lr_scheduler.*`. Defaults to
            `ExponentialLR`.

        lr_scheduler_params (dict):
            Parameters for the learning rate scheduler. Defaults to `{"gamma": 0.999875, "last_epoch": -1}`.

        scheduler_after_epoch (bool):
            If true, step the scheduler after each epoch else after each step. Defaults to `True`.

        optimizer (str):
            Name of the optimizer. One of the `torch.optim.*`. Defaults to `AdamW`.

        optimizer_params (dict):
            Keyword arguments for the optimizer. Defaults to
            `{"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01}`.

        loss_masking (bool):
            Enable / disable masking loss values against padded segments of samples in a batch.
            Defaults to `True`.

        feature_extractor (str):
            Name of the feature extractor used by the model. Defaults to `MFCC`.

        sort_by_audio_len (bool):
            If true, dataloader sorts the data by audio length else sorts by the input text length.
            Defaults to `True`.

        min_seq_len (int):
            Minimum sequence length to be considered for training. Defaults to `0`.

        max_seq_len (int):
            Maximum sequence length to be considered for training. Defaults to `500000`.

    Note:
        Check :class:`TTS.stt.configs.shared_configs.BaseSTTConfig` for the inherited parameters.

    Example:

        >>> from TTS.stt.configs import DeepSpeechConfig
        >>> config = DeepSpeechConfig()
    """

    model: str = "deep_speech"
    # model specific params
    model_args: DeepSpeechArgs = field(default_factory=DeepSpeechArgs)

    # optimizer
    grad_clip: float = 10
    lr: float = 0.0001
    lr_scheduler: str = "ExponentialLR"
    lr_scheduler_params: Dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
    scheduler_after_epoch: bool = True
    optimizer: str = "AdamW"
    optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01})

    # overrides of the defaults declared in BaseSTTConfig
    loss_masking: bool = True
    feature_extractor: str = "MFCC"
    sort_by_audio_len: bool = True
    min_seq_len: int = 0
    max_seq_len: int = 500000

View File

@ -0,0 +1,90 @@
from dataclasses import dataclass, field
from typing import Dict, List
from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig
@dataclass
class BaseSTTConfig(BaseTrainingConfig):
    """Shared parameters among all the STT (speech-to-text) models.

    Args:
        model_type (str):
            The type of the model that defines the trainer mode. Do not change this value. Defaults to `stt`.

        audio (BaseAudioConfig):
            Audio processor config object instance. Defaults to `BaseAudioConfig()`.

        enable_eos_bos_chars (bool):
            Enable / disable the use of eos and bos characters. Defaults to `False`.

        vocabulary (Dict[str, int]):
            Vocabulary used by the model. Defaults to `None`.

        batch_group_size (int):
            Size of the batch groups used for bucketing. By default, the dataloader orders samples by the sequence
            length for a more efficient and stable training. If `batch_group_size > 1` then it performs bucketing to
            prevent using the same batches for each epoch. Defaults to `0`.

        loss_masking (bool):
            Enable / disable masking loss values against padded segments of samples in a batch.
            Defaults to `None`.

        feature_extractor (str):
            Name of the feature extractor used by the model, one of the supported values by
            ```TTS.stt.datasets.dataset.FeatureExtractor```. Defaults to `None`.

        sort_by_audio_len (bool):
            If true, dataloader sorts the data by audio length else sorts by the input text length.
            Defaults to `True`.

        min_seq_len (int):
            Minimum sequence length to be used at training. Defaults to `1`.

        max_seq_len (int):
            Maximum sequence length to be used at training. Larger values result in more VRAM usage.
            Defaults to `float("inf")`, i.e. no upper limit.

        train_datasets (List[BaseDatasetConfig]):
            List of datasets used for training. If multiple datasets are provided, they are merged and used together
            for training. Defaults to `[BaseDatasetConfig()]`.

        eval_datasets (List[BaseDatasetConfig]):
            List of datasets used for evaluation. If multiple datasets are provided, they are merged and used together
            for evaluation. Defaults to `[BaseDatasetConfig()]`.

        optimizer (str):
            Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`.
            Defaults to `None`.

        optimizer_params (dict):
            Optimizer kwargs. Defaults to `None`.

        lr_scheduler (str):
            Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or
            `TTS.utils.training`. Defaults to an empty string.

        lr_scheduler_params (dict):
            Parameters for the learning rate scheduler. Defaults to `None`.
    """

    model_type: str = "stt"
    audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
    # phoneme settings
    enable_eos_bos_chars: bool = False
    # vocabulary parameters
    vocabulary: dict = None
    # training params
    batch_group_size: int = 0
    loss_masking: bool = None
    # dataloading
    feature_extractor: str = None
    sort_by_audio_len: bool = True
    min_seq_len: int = 1
    # NOTE(review): annotated as `int` but the default is the float `inf`; confirm that
    # downstream type checking / serialization of this field tolerates it.
    max_seq_len: int = float("inf")
    # dataset
    train_datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
    eval_datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
    # optimizer
    optimizer: str = None
    optimizer_params: dict = None
    # scheduler
    lr_scheduler: str = ""
    lr_scheduler_params: dict = None

View File

@ -0,0 +1,62 @@
import sys
from collections import Counter
from pathlib import Path
from typing import Dict, List, Tuple, Union
import numpy as np
from TTS.stt.datasets.formatters import *
def split_dataset(items, eval_split_size=None):
    """Shuffle ``items`` deterministically and split them into evaluation and training subsets.

    Args:
        items (List): Samples to split. The list is shuffled in place.
        eval_split_size (int, optional): Number of samples to put in the evaluation split. If ``None`` (or ``0``),
            1% of the samples, capped at 500, is used. Defaults to None.

    Returns:
        Tuple[List, List]: ``(eval_items, train_items)`` - the first ``eval_split_size`` shuffled samples and
            the remaining ones.

    Raises:
        AssertionError: If the resolved evaluation split size is not positive.
    """
    if not eval_split_size:
        eval_split_size = min(500, int(len(items) * 0.01))
    assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples."
    # A private generator seeded with 0 produces the same permutation as
    # `np.random.seed(0); np.random.shuffle(items)` without resetting the
    # process-wide NumPy random state that the rest of the training relies on.
    rng = np.random.RandomState(0)
    rng.shuffle(items)
    return items[:eval_split_size], items[eval_split_size:]
def load_stt_samples(
    datasets: Union[Dict, List[Dict]], eval_split=True, eval_split_size: int = None
) -> Tuple[List[List], List[List]]:
    """Parse the given dataset(s) and collect their samples into training and evaluation lists.

    Args:
        datasets (Union[Dict, List[Dict]]): A list of dataset configs or a single dataset config. Each config is
            indexed with the keys `name`, `path`, `meta_file_train` and `meta_file_val`.
        eval_split (bool, optional): If true, also return an evaluation split. When a dataset sets
            `meta_file_val`, the evaluation samples are loaded from it; otherwise they are split off from
            the training samples. Defaults to True.
        eval_split_size (int, optional): The size of the evaluation split when it is split off from the
            training samples. If None, use the pre-defined split size. Defaults to None.

    Returns:
        Tuple[List[List], List[List]]: training and evaluation samples merged over all the datasets. The
            evaluation entry is `None` when `eval_split` is false.
    """
    dataset_configs = datasets if isinstance(datasets, list) else [datasets]
    train_samples = []
    eval_samples = [] if eval_split else None

    for cfg in dataset_configs:
        dataset_name = cfg["name"]
        data_root = cfg["path"]
        train_meta = cfg["meta_file_train"]
        eval_meta = cfg["meta_file_val"]

        # the formatter is selected by the dataset name
        load_fn = _get_preprocessor_by_name(dataset_name)
        current_train = load_fn(data_root, train_meta)

        if eval_split:
            if eval_meta:
                # explicit evaluation meta file
                current_eval = load_fn(data_root, eval_meta)
            else:
                # no evaluation meta file: carve the split out of the training samples
                current_eval, current_train = split_dataset(current_train, eval_split_size)
            eval_samples.extend(current_eval)

        train_samples.extend(current_train)

    return train_samples, eval_samples
def _get_preprocessor_by_name(name):
    """Return the attribute of this module whose name equals ``name`` lower-cased.

    The formatter functions are star-imported into this module, so the lookup resolves a dataset name to
    its formatter. Raises `AttributeError` if this module has no attribute with that name.
    """
    formatter_name = name.lower()
    return getattr(sys.modules[__name__], formatter_name)

View File

@ -0,0 +1,14 @@
from TTS.stt.datasets.tokenizer import Tokenizer
from TTS.tts.utils.text.symbols import make_symbols, parse_symbols
from TTS.utils.generic_utils import find_module
def setup_stt_model(config):
    """Instantiate the STT model selected by ``config``.

    The implementation is looked up with `find_module` under `TTS.stt.models`, using `config.base_model`
    when that entry exists and is not None, and `config.model` otherwise. The name is lower-cased
    before the lookup.

    Args:
        config: Model config; it is also passed as the only argument to the model constructor.

    Returns:
        The model instance created from ``config``.
    """
    print(" > Using model: {}".format(config.model))
    # prefer the explicitly declared base model implementation when there is one
    use_base_model = "base_model" in config and config["base_model"] is not None
    module_name = config.base_model if use_base_model else config.model
    model_class = find_module("TTS.stt.models", module_name.lower())
    return model_class(config)