mirror of https://github.com/coqui-ai/TTS.git
71 lines
2.6 KiB
Python
71 lines
2.6 KiB
Python
import os
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict
|
|
|
|
from TTS.tts.configs.shared_configs import BaseTTSConfig
|
|
from TTS.tts.layers.bark.model import GPTConfig
|
|
from TTS.tts.layers.bark.model_fine import FineGPTConfig
|
|
from TTS.tts.models.bark import BarkAudioConfig
|
|
from TTS.utils.generic_utils import get_user_data_dir
|
|
|
|
|
|
@dataclass
class BarkConfig(BaseTTSConfig):
    """Configuration for the Bark TTS model.

    Bundles the audio config, the three sub-model configs (semantic, coarse and
    fine), a set of integer/float constants, and the remote/local checkpoint
    locations.

    The three ``*_MODEL_PATHS`` dictionaries are declared as ``None`` and are
    (re)built in ``__post_init__`` from ``REMOTE_BASE_URL`` and ``CACHE_DIR``,
    so any value passed for them to the constructor is overwritten.

    Attributes:
        model: Model name; defaults to ``"bark"``.
        audio: Audio configuration; a fresh ``BarkAudioConfig`` per instance.
        num_chars: Defaults to ``0``.
        semantic_config: Fresh ``GPTConfig`` per instance.
        fine_config: Fresh ``FineGPTConfig`` per instance.
        coarse_config: Fresh ``GPTConfig`` per instance.
        REMOTE_MODEL_PATHS: ``{"text"|"coarse"|"fine": {"path", "checksum"}}``.
        LOCAL_MODEL_PATHS: Checkpoint file paths under ``CACHE_DIR`` for
            ``text``, ``coarse``, ``fine``, ``hubert_tokenizer`` and ``hubert``.
        SMALL_REMOTE_MODEL_PATHS: ``{"text"|"coarse"|"fine": {"path"}}``
            (no checksums are recorded for these entries).
        CACHE_DIR: Local cache directory; the default is computed once, when
            the class body is evaluated, via ``get_user_data_dir``.
    """

    model: str = "bark"

    # Sub-configs are built through ``default_factory`` so every BarkConfig gets
    # its own objects. Using an already-constructed instance as the default
    # would share one object between all configs, and Python 3.11+ dataclasses
    # reject unhashable default values outright with a ValueError.
    audio: BarkAudioConfig = field(default_factory=BarkAudioConfig)
    num_chars: int = 0
    semantic_config: GPTConfig = field(default_factory=GPTConfig)
    fine_config: FineGPTConfig = field(default_factory=FineGPTConfig)
    coarse_config: GPTConfig = field(default_factory=GPTConfig)

    # Model constants. NOTE(review): their exact semantics are defined by the
    # Bark code that consumes them, not by this file -- confirm there before
    # changing any value.
    CONTEXT_WINDOW_SIZE: int = 1024
    SEMANTIC_RATE_HZ: float = 49.9
    SEMANTIC_VOCAB_SIZE: int = 10_000
    CODEBOOK_SIZE: int = 1024
    N_COARSE_CODEBOOKS: int = 2
    N_FINE_CODEBOOKS: int = 8
    COARSE_RATE_HZ: int = 75
    SAMPLE_RATE: int = 24_000
    USE_SMALLER_MODELS: bool = False

    # Token ids / offsets. NOTE(review): presumably special-token ids for the
    # text, semantic and coarse stages -- verify against the model code.
    TEXT_ENCODING_OFFSET: int = 10_048
    SEMANTIC_PAD_TOKEN: int = 10_000
    TEXT_PAD_TOKEN: int = 129_595
    SEMANTIC_INFER_TOKEN: int = 129_599
    COARSE_SEMANTIC_PAD_TOKEN: int = 12_048
    COARSE_INFER_TOKEN: int = 12_050

    # Deliberately left un-annotated: this makes it a plain class attribute
    # rather than a dataclass field, so it is not a constructor argument.
    REMOTE_BASE_URL = "https://huggingface.co/erogol/bark/tree/main/"
    REMOTE_MODEL_PATHS: Dict = None
    LOCAL_MODEL_PATHS: Dict = None
    SMALL_REMOTE_MODEL_PATHS: Dict = None
    CACHE_DIR: str = str(get_user_data_dir("tts/suno/bark_v0"))

    def __post_init__(self):
        """Populate the checkpoint path tables and mirror the sample rate.

        Rebuilds ``REMOTE_MODEL_PATHS``, ``LOCAL_MODEL_PATHS`` and
        ``SMALL_REMOTE_MODEL_PATHS`` from ``REMOTE_BASE_URL`` / ``CACHE_DIR``,
        then copies ``SAMPLE_RATE`` into ``sample_rate``.
        """
        # URLs always use forward slashes. ``os.path.join`` follows the host
        # OS convention and would insert a backslash on Windows whenever the
        # base URL lacks a trailing slash, so remote paths use ``posixpath``.
        import posixpath

        def remote(filename):
            return posixpath.join(self.REMOTE_BASE_URL, filename)

        def cached(filename):
            # Local files do live on the host filesystem: keep os.path here.
            return os.path.join(self.CACHE_DIR, filename)

        # NOTE(review): the checksums are 32 hex characters, presumably MD5 --
        # confirm against the code that verifies the downloads.
        self.REMOTE_MODEL_PATHS = {
            "text": {
                "path": remote("text_2.pt"),
                "checksum": "54afa89d65e318d4f5f80e8e8799026a",
            },
            "coarse": {
                "path": remote("coarse_2.pt"),
                "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
            },
            "fine": {
                "path": remote("fine_2.pt"),
                "checksum": "59d184ed44e3650774a2f0503a48a97b",
            },
        }
        self.LOCAL_MODEL_PATHS = {
            "text": cached("text_2.pt"),
            "coarse": cached("coarse_2.pt"),
            "fine": cached("fine_2.pt"),
            "hubert_tokenizer": cached("tokenizer.pth"),
            "hubert": cached("hubert.pt"),
        }
        self.SMALL_REMOTE_MODEL_PATHS = {
            "text": {"path": remote("text.pt")},
            "coarse": {"path": remote("coarse.pt")},
            "fine": {"path": remote("fine.pt")},
        }
        # ``sample_rate`` is not declared in this class; presumably it is a
        # field inherited from BaseTTSConfig -- TODO confirm.
        self.sample_rate = self.SAMPLE_RATE
|