refactor(xtts): remove duplicate xtts audio config

This commit is contained in:
Enno Hermann 2024-12-05 15:38:50 +01:00
parent ce202532cf
commit fe14ca6b68
7 changed files with 14 additions and 13 deletions

View File

@ -5,7 +5,8 @@ from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseDatasetConfig from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.tts.models.xtts import XttsAudioConfig
from TTS.utils.manage import ModelManager from TTS.utils.manage import ModelManager

View File

@ -18,7 +18,7 @@ from TTS.tts.layers.xtts.dvae import DiscreteVAE
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
from TTS.tts.models.base_tts import BaseTTS from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig from TTS.tts.models.xtts import Xtts, XttsArgs
from TTS.utils.generic_utils import is_pytorch_at_least_2_4 from TTS.utils.generic_utils import is_pytorch_at_least_2_4
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -34,11 +34,6 @@ class GPTTrainerConfig(XttsConfig):
test_sentences: List[dict] = field(default_factory=lambda: []) test_sentences: List[dict] = field(default_factory=lambda: [])
@dataclass
class XttsAudioConfig(XttsAudioConfig):
dvae_sample_rate: int = 22050
@dataclass @dataclass
class GPTArgs(XttsArgs): class GPTArgs(XttsArgs):
min_conditioning_length: int = 66150 min_conditioning_length: int = 66150

View File

@ -11,7 +11,6 @@ import torchaudio
from coqpit import Coqpit from coqpit import Coqpit
from trainer.io import load_fsspec from trainer.io import load_fsspec
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.layers.xtts.gpt import GPT from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.stream_generator import init_stream_support from TTS.tts.layers.xtts.stream_generator import init_stream_support
@ -103,10 +102,12 @@ class XttsAudioConfig(Coqpit):
Args: Args:
sample_rate (int): The sample rate in which the GPT operates. sample_rate (int): The sample rate in which the GPT operates.
output_sample_rate (int): The sample rate of the output audio waveform. output_sample_rate (int): The sample rate of the output audio waveform.
dvae_sample_rate (int): The sample rate of the DVAE
""" """
sample_rate: int = 22050 sample_rate: int = 22050
output_sample_rate: int = 24000 output_sample_rate: int = 24000
dvae_sample_rate: int = 22050
@dataclass @dataclass
@ -721,7 +722,7 @@ class Xtts(BaseTTS):
def load_checkpoint( def load_checkpoint(
self, self,
config: XttsConfig, config: "XttsConfig",
checkpoint_dir: Optional[str] = None, checkpoint_dir: Optional[str] = None,
checkpoint_path: Optional[str] = None, checkpoint_path: Optional[str] = None,
vocab_path: Optional[str] = None, vocab_path: Optional[str] = None,

View File

@ -4,7 +4,8 @@ from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseDatasetConfig from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.tts.models.xtts import XttsAudioConfig
from TTS.utils.manage import ModelManager from TTS.utils.manage import ModelManager
# Logging parameters # Logging parameters

View File

@ -4,7 +4,8 @@ from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseDatasetConfig from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.tts.models.xtts import XttsAudioConfig
from TTS.utils.manage import ModelManager from TTS.utils.manage import ModelManager
# Logging parameters # Logging parameters

View File

@ -8,7 +8,8 @@ from tests import get_tests_output_path
from TTS.config.shared_configs import BaseDatasetConfig from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.dvae import DiscreteVAE from TTS.tts.layers.xtts.dvae import DiscreteVAE
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.tts.models.xtts import XttsAudioConfig
config_dataset = BaseDatasetConfig( config_dataset = BaseDatasetConfig(
formatter="ljspeech", formatter="ljspeech",

View File

@ -8,7 +8,8 @@ from tests import get_tests_output_path
from TTS.config.shared_configs import BaseDatasetConfig from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.dvae import DiscreteVAE from TTS.tts.layers.xtts.dvae import DiscreteVAE
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.tts.models.xtts import XttsAudioConfig
config_dataset = BaseDatasetConfig( config_dataset = BaseDatasetConfig(
formatter="ljspeech", formatter="ljspeech",