"""Base configuration dataclasses (audio, dataset, training) shared by the model configs."""
from dataclasses import asdict, dataclass
from typing import List, Union

from coqpit import MISSING, Coqpit, check_argument


@dataclass
class BaseAudioConfig(Coqpit):
    """Base config to define audio processing parameters. It is used to initialize
    ```TTS.utils.audio.AudioProcessor.```

    Args:
        fft_size (int):
            Number of STFT frequency levels aka.size of the linear spectogram frame. Defaults to 1024.
        win_length (int):
            Each frame of audio is windowed by window of length ```win_length``` and then padded with zeros to match
            ```fft_size```. Defaults to 1024.
        hop_length (int):
            Number of audio samples between adjacent STFT columns. Defaults to 256.
        frame_shift_ms (int):
            Set ```hop_length``` based on milliseconds and sampling rate. Defaults to None.
        frame_length_ms (int):
            Set ```win_length``` based on milliseconds and sampling rate. Defaults to None.
        stft_pad_mode (str):
            Padding method used in STFT. 'reflect' or 'center'. Defaults to 'reflect'.
        sample_rate (int):
            Audio sampling rate. Defaults to 22050.
        resample (bool):
            Enable / Disable resampling audio to ```sample_rate```. Defaults to ```False```.
        preemphasis (float):
            Preemphasis coefficient. Defaults to 0.0.
        ref_level_db (int):
            Reference Db level to rebase the audio signal and ignore the level below. 20Db is assumed the sound of air.
            Defaults to 20.
        do_sound_norm (bool):
            Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False.
        log_func (str):
            Name of the log function, stored as a string. Defaults to ```np.log10```.
        do_trim_silence (bool):
            Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```.
        trim_db (int):
            Silence threshold used for silence trimming. Defaults to 45.
        power (float):
            Exponent used for expanding spectrogram levels before running Griffin Lim. It helps to reduce the
            artifacts in the synthesized voice. Defaults to 1.5.
        griffin_lim_iters (int):
            Number of Griffin Lim iterations. Defaults to 60.
        num_mels (int):
            Number of mel-basis frames that defines the frame lengths of each mel-spectrogram frame. Defaults to 80.
        mel_fmin (float):
            Min frequency level used for the mel-basis filters. ~50 for male and ~95 for female voices.
            It needs to be adjusted for a dataset. Defaults to 0.
        mel_fmax (float):
            Max frequency level used for the mel-basis filters. It needs to be adjusted for a dataset.
            Defaults to None.
        spec_gain (int):
            Gain applied when converting amplitude to DB. Defaults to 20.
        signal_norm (bool):
            enable/disable signal normalization. Defaults to True.
        min_level_db (int):
            minimum db threshold for the computed melspectrograms. Defaults to -100.
        symmetric_norm (bool):
            enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else
            [0, k], Defaults to True.
        max_norm (float):
            ```k``` defining the normalization range. Defaults to 4.0.
        clip_norm (bool):
            enable/disable clipping the out of range values in the normalized audio signal. Defaults to True.
        stats_path (str):
            Path to the computed stats file. Defaults to None.
    """

    # stft parameters
    fft_size: int = 1024
    win_length: int = 1024
    hop_length: int = 256
    frame_shift_ms: int = None
    frame_length_ms: int = None
    stft_pad_mode: str = "reflect"
    # audio processing parameters
    sample_rate: int = 22050
    resample: bool = False
    preemphasis: float = 0.0
    ref_level_db: int = 20
    do_sound_norm: bool = False
    # Annotated on purpose: @dataclass only turns *annotated* class attributes into
    # fields, so without ``: str`` this value was invisible to __init__ and asdict().
    log_func: str = "np.log10"
    # silence trimming
    do_trim_silence: bool = True
    trim_db: int = 45
    # griffin-lim params
    power: float = 1.5
    griffin_lim_iters: int = 60
    # mel-spec params
    num_mels: int = 80
    mel_fmin: float = 0.0
    mel_fmax: float = None
    spec_gain: int = 20
    # normalization params
    signal_norm: bool = True
    min_level_db: int = -100
    symmetric_norm: bool = True
    max_norm: float = 4.0
    clip_norm: bool = True
    stats_path: str = None

    def check_values(self):
        """Check config fields.

        Every field is validated through ``coqpit.check_argument`` against a dict snapshot
        of this instance; the accepted ranges are the ``min_val`` / ``max_val`` bounds
        passed below.
        """
        c = asdict(self)
        check_argument("num_mels", c, restricted=True, min_val=10, max_val=2056)
        check_argument("fft_size", c, restricted=True, min_val=128, max_val=4058)
        check_argument("sample_rate", c, restricted=True, min_val=512, max_val=100000)
        # The millisecond-based settings name their sample-based counterpart as the
        # ``alternative`` (frame_length_ms <-> win_length, frame_shift_ms <-> hop_length).
        check_argument(
            "frame_length_ms",
            c,
            restricted=True,
            min_val=10,
            max_val=1000,
            alternative="win_length",
        )
        check_argument("frame_shift_ms", c, restricted=True, min_val=1, max_val=1000, alternative="hop_length")
        check_argument("preemphasis", c, restricted=True, min_val=0, max_val=1)
        check_argument("min_level_db", c, restricted=True, min_val=-1000, max_val=10)
        check_argument("ref_level_db", c, restricted=True, min_val=0, max_val=1000)
        check_argument("power", c, restricted=True, min_val=1, max_val=5)
        check_argument("griffin_lim_iters", c, restricted=True, min_val=10, max_val=1000)

        # normalization parameters
        check_argument("signal_norm", c, restricted=True)
        check_argument("symmetric_norm", c, restricted=True)
        check_argument("max_norm", c, restricted=True, min_val=0.1, max_val=1000)
        check_argument("clip_norm", c, restricted=True)
        check_argument("mel_fmin", c, restricted=True, min_val=0.0, max_val=1000)
        # mel_fmax is the only field here that is explicitly allowed to stay None.
        check_argument("mel_fmax", c, restricted=True, min_val=500.0, allow_none=True)
        check_argument("spec_gain", c, restricted=True, min_val=1, max_val=100)
        check_argument("do_trim_silence", c, restricted=True)
        check_argument("trim_db", c, restricted=True)


@dataclass
class BaseDatasetConfig(Coqpit):
    """Base config describing a single dataset.

    Args:
        name (str):
            Name of the dataset. Defaults to None.
        path (str):
            Path of the dataset. Defaults to None.
        meta_file_train (str):
            Metadata file for the training split. Defaults to None.
        meta_file_val (str):
            Metadata file for the validation split. Defaults to None.
        meta_file_attn_mask (str):
            Metadata file for attention masks. Defaults to None.
    """

    name: str = None
    path: str = None
    meta_file_train: str = None
    meta_file_val: str = None
    meta_file_attn_mask: str = None

    def check_values(self):
        """Check config fields.

        ``name``, ``path`` and ``meta_file_train`` are checked with ``restricted=True``;
        the validation and attention-mask metadata files are checked with
        ``restricted=False``.
        """
        c = asdict(self)
        check_argument("name", c, restricted=True)
        check_argument("path", c, restricted=True)
        check_argument("meta_file_train", c, restricted=True)
        check_argument("meta_file_val", c, restricted=False)
        check_argument("meta_file_attn_mask", c, restricted=False)


@dataclass
class BaseTrainingConfig(Coqpit):
    """Base config to define the basic training parameters that are shared
    among all the models.

    Args:
        model (str):
            Name of the model. Defaults to None.
        run_name (str):
            Name of the training run. Defaults to an empty string.
        run_description (str):
            Description of the training run. Defaults to an empty string.
        epochs (int):
            Number of training epochs. Defaults to 10000.
        batch_size (int):
            Training batch size. Declared as ```MISSING```, so it has no usable default.
        batch_group_size (int):
            Number of batches to shuffle after bucketing.
        eval_batch_size (int):
            Validation batch size. Defaults to None.
        loss_masking (bool):
            Enable / Disable masking padding segments of sequences.
        mixed_precision (bool):
            Enable / Disable mixed precision training. It reduces the VRAM use and allows larger batch sizes, however
            it may also cause numerical instability in some cases. Defaults to ```False```.
        run_eval (bool):
            Enable / Disable evaluation (validation) run. Defaults to True.
        test_delay_epochs (int):
            Number of epochs before starting to use evaluation runs. Initially, models do not generate meaningful
            results, hence waiting for a couple of epochs might save some time. Defaults to 0.
        print_eval (bool):
            Enable / Disable console logging for evaluation steps. If disabled then it only shows the final values at
            the end of the evaluation. Default to ```False```.
        print_step (int):
            Number of steps required to print the next training log. Defaults to 25.
        tb_plot_step (int):
            Number of steps required to log training on Tensorboard. Defaults to 100.
        tb_model_param_stats (bool):
            Enable / Disable logging internal model stats for model diagnostic. It might be useful for model debugging.
            Defaults to ```False```.
        save_step (int):
            Number of steps required to save the next checkpoint. Defaults to 10000.
        checkpoint (bool):
            Enable / Disable checkpointing. Defaults to ```True```.
        keep_all_best (bool):
            Enable / Disable keeping all the saved best models instead of overwriting the previous one. Defaults
            to ```False```.
        keep_after (int):
            Number of steps to wait before saving all the best models. In use if ```keep_all_best == True```. Defaults
            to 10000.
        text_cleaner (str):
            Text cleaner to be used at model training. It is set to be one of the cleaners in
            ```TTS.tts.utils.text.cleaners```.
        enable_eos_bos_chars (bool):
            Enable / Disable using special characters indicating end-of-sentence and beginning-of-sentence.
        num_loader_workers (int):
            Number of workers for training time dataloader. Defaults to None.
        num_val_loader_workers (int):
            Number of workers for evaluation time dataloader. Defaults to None.
        use_noise_augment (bool):
            Enable / Disable noise augmentation in data loading. Defaults to ```False```.
        min_seq_len (int):
            Minimum sequence length to be used at training.
        max_seq_len (int):
            Maximum sequence length to be used at training. VRAM use at training depends on this parameter. Consider to
            decrease it if you get OOM errors.
        compute_f0 (bool):
            Return F0 frames from the dataloader. Defaults to ```False```.
        compute_input_seq_cache (bool):
            Enable / Disable computing and caching phonemes sequences from character sequences at the beginning of the
            training. It allows faster data loading times and more precise max-min sequence pruning. Defaults
            to ```False```.
        output_path (str):
            Path for training output folder. The non-existent part of the given path is created automatically.
            All training outputs are saved there. Defaults to None.
        phoneme_cache_path (str):
            Path to a folder to save the computed phoneme sequences.
        datasets (List[BaseDatasetConfig]):
            List of DatasetConfig.
        distributed_backend (str):
            Backend used for distributed training. Defaults to ```nccl```.
        distributed_url (str):
            URL used to set up distributed training. Defaults to ```tcp://localhost:54321```.

    """

    model: str = None
    run_name: str = ""
    run_description: str = ""
    # training params
    epochs: int = 10000
    batch_size: int = MISSING
    eval_batch_size: int = None
    mixed_precision: bool = False
    # eval params
    run_eval: bool = True
    test_delay_epochs: int = 0
    print_eval: bool = False
    # logging
    print_step: int = 25
    tb_plot_step: int = 100
    tb_model_param_stats: bool = False
    # checkpointing
    save_step: int = 10000
    checkpoint: bool = True
    keep_all_best: bool = False
    keep_after: int = 10000
    # dataloading
    num_loader_workers: int = None
    num_val_loader_workers: int = None
    use_noise_augment: bool = False
    # paths
    output_path: str = None
    # distributed
    distributed_backend: str = "nccl"
    distributed_url: str = "tcp://localhost:54321"