From 5ee01e562428c6aa8e8620226d1045ea770bbe27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 30 Sep 2021 14:43:26 +0000 Subject: [PATCH] implement DeepSpeechConfig --- TTS/stt/__init__.py | 0 TTS/stt/configs/__init__.py | 17 ++++ TTS/stt/configs/deep_speech_config.py | 112 ++++++++++++++++++++++++++ TTS/stt/configs/shared_configs.py | 90 +++++++++++++++++++++ TTS/stt/datasets/__init__.py | 62 ++++++++++++++ TTS/stt/models/__init__.py | 14 ++++ 6 files changed, 295 insertions(+) create mode 100644 TTS/stt/__init__.py create mode 100644 TTS/stt/configs/__init__.py create mode 100644 TTS/stt/configs/deep_speech_config.py create mode 100644 TTS/stt/configs/shared_configs.py create mode 100644 TTS/stt/datasets/__init__.py create mode 100644 TTS/stt/models/__init__.py diff --git a/TTS/stt/__init__.py b/TTS/stt/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/TTS/stt/configs/__init__.py b/TTS/stt/configs/__init__.py new file mode 100644 index 00000000..1e61a4ad --- /dev/null +++ b/TTS/stt/configs/__init__.py @@ -0,0 +1,17 @@ +import importlib +import os +from inspect import isclass + +# import all files under configs/ +configs_dir = os.path.dirname(__file__) +for file in os.listdir(configs_dir): + path = os.path.join(configs_dir, file) + if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)): + config_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("TTS.stt.configs." + config_name) + for attribute_name in dir(module): + attribute = getattr(module, attribute_name) + + if isclass(attribute): + # Add the class to this package's variables + globals()[attribute_name] = attribute diff --git a/TTS/stt/configs/deep_speech_config.py b/TTS/stt/configs/deep_speech_config.py new file mode 100644 index 00000000..3d10766d --- /dev/null +++ b/TTS/stt/configs/deep_speech_config.py @@ -0,0 +1,112 @@ +from dataclasses import dataclass, field +from typing import Dict, List + +from TTS.stt.configs.shared_configs import BaseSTTConfig +from TTS.stt.models.deep_speech import DeepSpeechArgs + + +@dataclass +class DeepSpeechConfig(BaseSTTConfig): + """Defines parameters for VITS End2End TTS model. + + Args: + model (str): + Model name. Do not change unless you know what you are doing. + + model_args (VitsArgs): + Model architecture arguments. Defaults to `VitsArgs()`. + + grad_clip (List): + Gradient clipping thresholds for each optimizer. Defaults to `[5.0, 5.0]`. + + lr (float): + Initial learning rate. Defaults to 0.0002. + + lr_scheduler_gen (str): + Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to + `ExponentialLR`. + + lr_scheduler_gen_params (dict): + Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`. + + lr_scheduler_disc (str): + Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to + `ExponentialLR`. + + lr_scheduler_disc_params (dict): + Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`. + + scheduler_after_epoch (bool): + If true, step the schedulers after each epoch else after each step. Defaults to `False`. + + optimizer (str): + Name of the optimizer to use with both the generator and the discriminator networks. One of the + `torch.optim.*`. Defaults to `AdamW`. + + kl_loss_alpha (float): + Loss weight for KL loss. Defaults to 1.0. + + disc_loss_alpha (float): + Loss weight for the discriminator loss. Defaults to 1.0. + + gen_loss_alpha (float): + Loss weight for the generator loss. Defaults to 1.0. + + feat_loss_alpha (float): + Loss weight for the feature matching loss. Defaults to 1.0. + + mel_loss_alpha (float): + Loss weight for the mel loss. Defaults to 45.0. + + return_wav (bool): + If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`. + + compute_linear_spec (bool): + If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`. + + sort_by_audio_len (bool): + If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `True`. + + min_seq_len (int): + Minimum sequnce length to be considered for training. Defaults to `0`. + + max_seq_len (int): + Maximum sequnce length to be considered for training. Defaults to `500000`. + + r (int): + Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`. + + add_blank (bool): + If true, a blank token is added in between every character. Defaults to `True`. + + test_sentences (List[str]): + List of sentences to be used for testing. + + Note: + Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters. + + Example: + + >>> from TTS.stt.configs import DeepSpeechConfig + >>> config = DeepSpeechConfig() + """ + + model: str = "deep_speech" + # model specific params + model_args: DeepSpeechArgs = field(default_factory=DeepSpeechArgs) + + # optimizer + grad_clip: float = 10 + lr: float = 0.0001 + lr_scheduler: str = "ExponentialLR" + lr_scheduler_params: Dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1}) + scheduler_after_epoch: bool = True + optimizer: str = "AdamW" + optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01}) + + # overrides + loss_masking: bool = True + feature_extractor: str = "MFCC" + sort_by_audio_len: bool = True + min_seq_len: int = 0 + max_seq_len: int = 500000 diff --git a/TTS/stt/configs/shared_configs.py b/TTS/stt/configs/shared_configs.py new file mode 100644 index 00000000..228fc0a1 --- /dev/null +++ b/TTS/stt/configs/shared_configs.py @@ -0,0 +1,90 @@ +from dataclasses import dataclass, field +from typing import Dict, List + +from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig + + +@dataclass +class BaseSTTConfig(BaseTrainingConfig): + """Shared parameters among all the tts models. + + Args: + model_type (str): + The type of the model that defines the trainer mode. Do not change this value. + + audio (BaseAudioConfig): + Audio processor config object instance. + + enable_eos_bos_chars (bool): + enable / disable the use of eos and bos characters. + + vocabulary (Dict[str, int]): + vocabulary used by the model + + batch_group_size (int): + Size of the batch groups used for bucketing. By default, the dataloader orders samples by the sequence + length for a more efficient and stable training. If `batch_group_size > 1` then it performs bucketing to + prevent using the same batches for each epoch. + + loss_masking (bool): + enable / disable masking loss values against padded segments of samples in a batch. + + feature_extractor (str): + Name of the feature extractor used by the model, one of the supperted values by + ```TTS.stt.datasets.dataset.FeatureExtractor```. Defaults to None. + + sort_by_audio_len (bool): + If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `False`. + + min_seq_len (int): + Minimum sequence length to be used at training. + + max_seq_len (int): + Maximum sequence length to be used at training. Larger values result in more VRAM usage. + + train_datasets (List[BaseDatasetConfig]): + List of datasets used for training. If multiple datasets are provided, they are merged and used together + for training. + + eval_datasets (List[BaseDatasetConfig]): + List of datasets used for evaluation. If multiple datasets are provided, they are merged and used together + for training. + + optimizer (str): + Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`. + Defaults to ``. + + optimizer_params (dict): + Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}` + + lr_scheduler (str): + Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or + `TTS.utils.training`. Defaults to ``. + + lr_scheduler_params (dict): + Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`. + """ + + model_type: str = "stt" + audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) + # phoneme settings + enable_eos_bos_chars: bool = False + # vocabulary parameters + vocabulary: dict = None + # training params + batch_group_size: int = 0 + loss_masking: bool = None + # dataloading + feature_extractor: str = None + sort_by_audio_len: bool = True + min_seq_len: int = 1 + max_seq_len: int = float("inf") + # dataset + train_datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) + eval_datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) + # optimizer + optimizer: str = None + optimizer_params: dict = None + # scheduler + lr_scheduler: str = "" + lr_scheduler_params: dict = None diff --git a/TTS/stt/datasets/__init__.py b/TTS/stt/datasets/__init__.py new file mode 100644 index 00000000..fe65e6c9 --- /dev/null +++ b/TTS/stt/datasets/__init__.py @@ -0,0 +1,62 @@ +import sys +from collections import Counter +from pathlib import Path +from typing import Dict, List, Tuple, Union + +import numpy as np + +from TTS.stt.datasets.formatters import * + + +def split_dataset(items, eval_split_size=None): + if not eval_split_size: + eval_split_size = min(500, int(len(items) * 0.01)) + assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples." + np.random.seed(0) + np.random.shuffle(items) + return items[:eval_split_size], items[eval_split_size:] + + +def load_stt_samples( + datasets: Union[Dict, List[Dict]], eval_split=True, eval_split_size: int = None +) -> Tuple[List[List], List[List]]: + """Parse the dataset, load the samples as a list and load the attention alignments if provided. + + Args: + datasets (Union[Dict, List[Dict]]): A list of dataset configs or a single dataset config. + eval_split (bool, optional): If true, create a evaluation split. If an eval split provided explicitly, generate + an eval split automatically. Defaults to True. + eval_split_size (int, optional): The size of the evaluation split. If None, use pre-defined split size. + Defaults to None. + + Returns: + Tuple[List[List], List[List]: training and evaluation splits of the dataset. + """ + if not isinstance(datasets, list): + datasets = [datasets] + meta_data_train_all = [] + meta_data_eval_all = [] if eval_split else None + for dataset in datasets: + name = dataset["name"] + root_path = dataset["path"] + meta_file_train = dataset["meta_file_train"] + meta_file_val = dataset["meta_file_val"] + # setup the right data processor + preprocessor = _get_preprocessor_by_name(name) + # load train set + meta_data_train = preprocessor(root_path, meta_file_train) + # load evaluation split if set + if eval_split: + if meta_file_val: + meta_data_eval = preprocessor(root_path, meta_file_val) + else: + meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_size) + meta_data_eval_all += meta_data_eval + meta_data_train_all += meta_data_train + return meta_data_train_all, meta_data_eval_all + + +def _get_preprocessor_by_name(name): + """Returns the respective preprocessing function.""" + thismodule = sys.modules[__name__] + return getattr(thismodule, name.lower()) diff --git a/TTS/stt/models/__init__.py b/TTS/stt/models/__init__.py new file mode 100644 index 00000000..70981a6e --- /dev/null +++ b/TTS/stt/models/__init__.py @@ -0,0 +1,14 @@ +from TTS.stt.datasets.tokenizer import Tokenizer +from TTS.tts.utils.text.symbols import make_symbols, parse_symbols +from TTS.utils.generic_utils import find_module + + +def setup_stt_model(config): + print(" > Using model: {}".format(config.model)) + # fetch the right model implementation. + if "base_model" in config and config["base_model"] is not None: + MyModel = find_module("TTS.stt.models", config.base_model.lower()) + else: + MyModel = find_module("TTS.stt.models", config.model.lower()) + model = MyModel(config) + return model