From fe14ca6b68f8757f581ec04d2d0becddd7031d05 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Thu, 5 Dec 2024 15:38:50 +0100 Subject: [PATCH] refactor(xtts): remove duplicate xtts audio config --- TTS/demos/xtts_ft_demo/utils/gpt_train.py | 3 ++- TTS/tts/layers/xtts/trainer/gpt_trainer.py | 7 +------ TTS/tts/models/xtts.py | 5 +++-- recipes/ljspeech/xtts_v1/train_gpt_xtts.py | 3 ++- recipes/ljspeech/xtts_v2/train_gpt_xtts.py | 3 ++- tests/xtts_tests/test_xtts_gpt_train.py | 3 ++- tests/xtts_tests/test_xtts_v2-0_gpt_train.py | 3 ++- 7 files changed, 14 insertions(+), 13 deletions(-) diff --git a/TTS/demos/xtts_ft_demo/utils/gpt_train.py b/TTS/demos/xtts_ft_demo/utils/gpt_train.py index f838297a..411a9b0d 100644 --- a/TTS/demos/xtts_ft_demo/utils/gpt_train.py +++ b/TTS/demos/xtts_ft_demo/utils/gpt_train.py @@ -5,7 +5,8 @@ from trainer import Trainer, TrainerArgs from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples -from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig +from TTS.tts.models.xtts import XttsAudioConfig from TTS.utils.manage import ModelManager diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py index 0253d65d..10705418 100644 --- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py +++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py @@ -18,7 +18,7 @@ from TTS.tts.layers.xtts.dvae import DiscreteVAE from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset from TTS.tts.models.base_tts import BaseTTS -from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig +from TTS.tts.models.xtts import Xtts, XttsArgs from TTS.utils.generic_utils import is_pytorch_at_least_2_4 logger = logging.getLogger(__name__) @@ -34,11 +34,6 @@ class GPTTrainerConfig(XttsConfig): test_sentences: List[dict] = field(default_factory=lambda: []) -@dataclass -class XttsAudioConfig(XttsAudioConfig): - dvae_sample_rate: int = 22050 - - @dataclass class GPTArgs(XttsArgs): min_conditioning_length: int = 66150 diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index d780e2b3..f05863ae 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -11,7 +11,6 @@ import torchaudio from coqpit import Coqpit from trainer.io import load_fsspec -from TTS.tts.configs.xtts_config import XttsConfig from TTS.tts.layers.xtts.gpt import GPT from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder from TTS.tts.layers.xtts.stream_generator import init_stream_support @@ -103,10 +102,12 @@ class XttsAudioConfig(Coqpit): Args: sample_rate (int): The sample rate in which the GPT operates. output_sample_rate (int): The sample rate of the output audio waveform. + dvae_sample_rate (int): The sample rate of the DVAE """ sample_rate: int = 22050 output_sample_rate: int = 24000 + dvae_sample_rate: int = 22050 @dataclass @@ -721,7 +722,7 @@ class Xtts(BaseTTS): def load_checkpoint( self, - config: XttsConfig, + config: "XttsConfig", checkpoint_dir: Optional[str] = None, checkpoint_path: Optional[str] = None, vocab_path: Optional[str] = None, diff --git a/recipes/ljspeech/xtts_v1/train_gpt_xtts.py b/recipes/ljspeech/xtts_v1/train_gpt_xtts.py index d31ec8f1..a077a180 100644 --- a/recipes/ljspeech/xtts_v1/train_gpt_xtts.py +++ b/recipes/ljspeech/xtts_v1/train_gpt_xtts.py @@ -4,7 +4,8 @@ from trainer import Trainer, TrainerArgs from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples -from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig +from TTS.tts.models.xtts import XttsAudioConfig from TTS.utils.manage import ModelManager # Logging parameters diff --git a/recipes/ljspeech/xtts_v2/train_gpt_xtts.py b/recipes/ljspeech/xtts_v2/train_gpt_xtts.py index ccaa97f1..362f4500 100644 --- a/recipes/ljspeech/xtts_v2/train_gpt_xtts.py +++ b/recipes/ljspeech/xtts_v2/train_gpt_xtts.py @@ -4,7 +4,8 @@ from trainer import Trainer, TrainerArgs from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples -from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig +from TTS.tts.models.xtts import XttsAudioConfig from TTS.utils.manage import ModelManager # Logging parameters diff --git a/tests/xtts_tests/test_xtts_gpt_train.py b/tests/xtts_tests/test_xtts_gpt_train.py index b8b9a4e3..bb592f1f 100644 --- a/tests/xtts_tests/test_xtts_gpt_train.py +++ b/tests/xtts_tests/test_xtts_gpt_train.py @@ -8,7 +8,8 @@ from tests import get_tests_output_path from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.layers.xtts.dvae import DiscreteVAE -from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig +from TTS.tts.models.xtts import XttsAudioConfig config_dataset = BaseDatasetConfig( formatter="ljspeech", diff --git a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py index 6663433c..454e8673 100644 --- a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py +++ b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py @@ -8,7 +8,8 @@ from tests import get_tests_output_path from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.layers.xtts.dvae import DiscreteVAE -from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig +from TTS.tts.models.xtts import XttsAudioConfig config_dataset = BaseDatasetConfig( formatter="ljspeech",