mirror of https://github.com/coqui-ai/TTS.git
85 lines
2.8 KiB
Python
85 lines
2.8 KiB
Python
from dataclasses import asdict, dataclass, field
|
|
from typing import List
|
|
|
|
from coqpit import MISSING, Coqpit, check_argument
|
|
|
|
from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig
|
|
|
|
|
|
@dataclass
class GSTConfig(Coqpit):
    """Configuration of the Global Style Token (GST) module.

    Fields (with their defaults):
        gst_style_input_wav (str): defaults to ``None``.
        gst_style_input_weights (dict): defaults to ``None``.
        gst_embedding_dim (int): defaults to ``256``; validated with
            ``min_val=0`` and ``max_val=1000``.
        gst_use_speaker_embedding (bool): defaults to ``False``.
        gst_num_heads (int): defaults to ``4``; validated with
            ``min_val=2`` and ``max_val=10``.
        gst_num_style_tokens (int): defaults to ``10``; validated with
            ``min_val=1`` and ``max_val=1000``.
    """

    gst_style_input_wav: str = None
    gst_style_input_weights: dict = None
    gst_embedding_dim: int = 256
    gst_use_speaker_embedding: bool = False
    gst_num_heads: int = 4
    gst_num_style_tokens: int = 10

    def check_values(self):
        """Validate every GST field through ``check_argument``."""
        # Snapshot the fields first, then let the parent class run its own checks.
        values = asdict(self)
        super().check_values()

        # (field name, keyword arguments for check_argument), listed in the
        # order in which the checks are executed.
        checks = (
            ("gst_style_input_weights", {"restricted": False}),
            ("gst_style_input_wav", {"restricted": False}),
            ("gst_embedding_dim", {"restricted": True, "min_val": 0, "max_val": 1000}),
            ("gst_use_speaker_embedding", {"restricted": False}),
            ("gst_num_heads", {"restricted": True, "min_val": 2, "max_val": 10}),
            ("gst_num_style_tokens", {"restricted": True, "min_val": 1, "max_val": 1000}),
        )
        for field_name, options in checks:
            check_argument(field_name, values, **options)
|
|
|
|
|
|
@dataclass
class CharactersConfig(Coqpit):
    """Defines the character or phoneme set used by the model.

    Fields (all string fields default to ``None``):
        pad (str): padding symbol.
        eos (str): end-of-sequence symbol.
        bos (str): beginning-of-sequence symbol.
        characters (str): the character set.
        punctuations (str): the punctuation set.
        phonemes (str): the phoneme set.
        unique (bool): defaults to ``True``; kept for backwards compatibility
            of models trained with character sets that contain duplicates.
    """

    pad: str = None
    eos: str = None
    bos: str = None
    characters: str = None
    punctuations: str = None
    phonemes: str = None
    unique: bool = True  # for backwards compatibility of models trained with char sets with duplicates

    def check_values(
        self,
    ):
        """Check config fields.

        Every symbol field is validated with ``check_argument``; the fields
        that depend on the character set declare ``characters`` as their
        prerequisite field.
        """
        c = asdict(self)
        # "characters" is passed by keyword on purpose: given positionally it
        # would bind to check_argument's third positional parameter (`is_path`)
        # instead of `prerequest`, turning these into path-existence checks.
        check_argument("pad", c, prerequest="characters", restricted=True)
        check_argument("eos", c, prerequest="characters", restricted=True)
        check_argument("bos", c, prerequest="characters", restricted=True)
        check_argument("characters", c, prerequest="characters", restricted=True)
        check_argument("phonemes", c, restricted=True)
        check_argument("punctuations", c, prerequest="characters", restricted=True)
|
|
|
|
|
|
@dataclass
class BaseTTSConfig(BaseTrainingConfig):
    """Parameters shared by all the TTS models.

    Extends ``BaseTrainingConfig`` with audio, text/phoneme, vocabulary,
    training, data-loading and dataset fields. ``text_cleaner`` is the only
    field declared with ``MISSING`` instead of a concrete default.
    """

    audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)

    # --- phoneme settings ---
    use_phonemes: bool = False
    phoneme_language: str = None
    compute_input_seq_cache: bool = False
    text_cleaner: str = MISSING
    enable_eos_bos_chars: bool = False
    test_sentences_file: str = ""
    phoneme_cache_path: str = None

    # --- vocabulary parameters ---
    characters: CharactersConfig = None

    # --- training parameters ---
    batch_group_size: int = 0
    loss_masking: bool = None

    # --- data loading ---
    min_seq_len: int = 1
    # Annotated as int, but the default is floating-point infinity.
    max_seq_len: int = float("inf")
    compute_f0: bool = False
    use_noise_augment: bool = False
    add_blank: bool = False

    # --- dataset ---
    # A factory is used so that each config instance gets its own list.
    datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
|