mirror of https://github.com/coqui-ai/TTS.git
refactor(xtts): remove duplicate xtts audio config
This commit is contained in:
parent
ce202532cf
commit
fe14ca6b68
|
@ -5,7 +5,8 @@ from trainer import Trainer, TrainerArgs
|
|||
|
||||
from TTS.config.shared_configs import BaseDatasetConfig
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
|
||||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
|
||||
from TTS.tts.models.xtts import XttsAudioConfig
|
||||
from TTS.utils.manage import ModelManager
|
||||
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ from TTS.tts.layers.xtts.dvae import DiscreteVAE
|
|||
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
|
||||
from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
|
||||
from TTS.tts.models.xtts import Xtts, XttsArgs
|
||||
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -34,11 +34,6 @@ class GPTTrainerConfig(XttsConfig):
|
|||
test_sentences: List[dict] = field(default_factory=lambda: [])
|
||||
|
||||
|
||||
@dataclass
|
||||
class XttsAudioConfig(XttsAudioConfig):
|
||||
dvae_sample_rate: int = 22050
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPTArgs(XttsArgs):
|
||||
min_conditioning_length: int = 66150
|
||||
|
|
|
@ -11,7 +11,6 @@ import torchaudio
|
|||
from coqpit import Coqpit
|
||||
from trainer.io import load_fsspec
|
||||
|
||||
from TTS.tts.configs.xtts_config import XttsConfig
|
||||
from TTS.tts.layers.xtts.gpt import GPT
|
||||
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
|
||||
from TTS.tts.layers.xtts.stream_generator import init_stream_support
|
||||
|
@ -103,10 +102,12 @@ class XttsAudioConfig(Coqpit):
|
|||
Args:
|
||||
sample_rate (int): The sample rate in which the GPT operates.
|
||||
output_sample_rate (int): The sample rate of the output audio waveform.
|
||||
dvae_sample_rate (int): The sample rate of the DVAE
|
||||
"""
|
||||
|
||||
sample_rate: int = 22050
|
||||
output_sample_rate: int = 24000
|
||||
dvae_sample_rate: int = 22050
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -721,7 +722,7 @@ class Xtts(BaseTTS):
|
|||
|
||||
def load_checkpoint(
|
||||
self,
|
||||
config: XttsConfig,
|
||||
config: "XttsConfig",
|
||||
checkpoint_dir: Optional[str] = None,
|
||||
checkpoint_path: Optional[str] = None,
|
||||
vocab_path: Optional[str] = None,
|
||||
|
|
|
@ -4,7 +4,8 @@ from trainer import Trainer, TrainerArgs
|
|||
|
||||
from TTS.config.shared_configs import BaseDatasetConfig
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
|
||||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
|
||||
from TTS.tts.models.xtts import XttsAudioConfig
|
||||
from TTS.utils.manage import ModelManager
|
||||
|
||||
# Logging parameters
|
||||
|
|
|
@ -4,7 +4,8 @@ from trainer import Trainer, TrainerArgs
|
|||
|
||||
from TTS.config.shared_configs import BaseDatasetConfig
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
|
||||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
|
||||
from TTS.tts.models.xtts import XttsAudioConfig
|
||||
from TTS.utils.manage import ModelManager
|
||||
|
||||
# Logging parameters
|
||||
|
|
|
@ -8,7 +8,8 @@ from tests import get_tests_output_path
|
|||
from TTS.config.shared_configs import BaseDatasetConfig
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.layers.xtts.dvae import DiscreteVAE
|
||||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
|
||||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
|
||||
from TTS.tts.models.xtts import XttsAudioConfig
|
||||
|
||||
config_dataset = BaseDatasetConfig(
|
||||
formatter="ljspeech",
|
||||
|
|
|
@ -8,7 +8,8 @@ from tests import get_tests_output_path
|
|||
from TTS.config.shared_configs import BaseDatasetConfig
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.layers.xtts.dvae import DiscreteVAE
|
||||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
|
||||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
|
||||
from TTS.tts.models.xtts import XttsAudioConfig
|
||||
|
||||
config_dataset = BaseDatasetConfig(
|
||||
formatter="ljspeech",
|
||||
|
|
Loading…
Reference in New Issue