Merge pull request #184 from idiap/xtts-error

fix(xtts): clearer error message when file given to checkpoint_dir
This commit is contained in:
Enno Hermann 2024-12-06 06:46:48 +01:00 committed by GitHub
commit e8d99aaf2b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 25 additions and 20 deletions

View File

@ -5,7 +5,8 @@ from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseDatasetConfig from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.tts.models.xtts import XttsAudioConfig
from TTS.utils.manage import ModelManager from TTS.utils.manage import ModelManager

View File

@ -18,7 +18,7 @@ from TTS.tts.layers.xtts.dvae import DiscreteVAE
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
from TTS.tts.models.base_tts import BaseTTS from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig from TTS.tts.models.xtts import Xtts, XttsArgs
from TTS.utils.generic_utils import is_pytorch_at_least_2_4 from TTS.utils.generic_utils import is_pytorch_at_least_2_4
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -34,11 +34,6 @@ class GPTTrainerConfig(XttsConfig):
test_sentences: List[dict] = field(default_factory=lambda: []) test_sentences: List[dict] = field(default_factory=lambda: [])
@dataclass
class XttsAudioConfig(XttsAudioConfig):
dvae_sample_rate: int = 22050
@dataclass @dataclass
class GPTArgs(XttsArgs): class GPTArgs(XttsArgs):
min_conditioning_length: int = 66150 min_conditioning_length: int = 66150

View File

@ -2,6 +2,7 @@ import logging
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Optional
import librosa import librosa
import torch import torch
@ -101,10 +102,12 @@ class XttsAudioConfig(Coqpit):
Args: Args:
sample_rate (int): The sample rate in which the GPT operates. sample_rate (int): The sample rate in which the GPT operates.
output_sample_rate (int): The sample rate of the output audio waveform. output_sample_rate (int): The sample rate of the output audio waveform.
dvae_sample_rate (int): The sample rate of the DVAE
""" """
sample_rate: int = 22050 sample_rate: int = 22050
output_sample_rate: int = 24000 output_sample_rate: int = 24000
dvae_sample_rate: int = 22050
@dataclass @dataclass
@ -719,14 +722,14 @@ class Xtts(BaseTTS):
def load_checkpoint( def load_checkpoint(
self, self,
config, config: "XttsConfig",
checkpoint_dir=None, checkpoint_dir: Optional[str] = None,
checkpoint_path=None, checkpoint_path: Optional[str] = None,
vocab_path=None, vocab_path: Optional[str] = None,
eval=True, eval: bool = True,
strict=True, strict: bool = True,
use_deepspeed=False, use_deepspeed: bool = False,
speaker_file_path=None, speaker_file_path: Optional[str] = None,
): ):
""" """
Loads a checkpoint from disk and initializes the model's state and tokenizer. Loads a checkpoint from disk and initializes the model's state and tokenizer.
@ -742,7 +745,9 @@ class Xtts(BaseTTS):
Returns: Returns:
None None
""" """
if checkpoint_dir is not None and Path(checkpoint_dir).is_file():
msg = f"You passed a file to `checkpoint_dir=`. Use `checkpoint_path={checkpoint_dir}` instead."
raise ValueError(msg)
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth") model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
if vocab_path is None: if vocab_path is None:
if checkpoint_dir is not None and (Path(checkpoint_dir) / "vocab.json").is_file(): if checkpoint_dir is not None and (Path(checkpoint_dir) / "vocab.json").is_file():

View File

@ -4,7 +4,8 @@ from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseDatasetConfig from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.tts.models.xtts import XttsAudioConfig
from TTS.utils.manage import ModelManager from TTS.utils.manage import ModelManager
# Logging parameters # Logging parameters

View File

@ -4,7 +4,8 @@ from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseDatasetConfig from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.tts.models.xtts import XttsAudioConfig
from TTS.utils.manage import ModelManager from TTS.utils.manage import ModelManager
# Logging parameters # Logging parameters

View File

@ -8,7 +8,8 @@ from tests import get_tests_output_path
from TTS.config.shared_configs import BaseDatasetConfig from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.dvae import DiscreteVAE from TTS.tts.layers.xtts.dvae import DiscreteVAE
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.tts.models.xtts import XttsAudioConfig
config_dataset = BaseDatasetConfig( config_dataset = BaseDatasetConfig(
formatter="ljspeech", formatter="ljspeech",

View File

@ -8,7 +8,8 @@ from tests import get_tests_output_path
from TTS.config.shared_configs import BaseDatasetConfig from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.dvae import DiscreteVAE from TTS.tts.layers.xtts.dvae import DiscreteVAE
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.tts.models.xtts import XttsAudioConfig
config_dataset = BaseDatasetConfig( config_dataset = BaseDatasetConfig(
formatter="ljspeech", formatter="ljspeech",