mirror of https://github.com/coqui-ai/TTS.git
Merge pull request #184 from idiap/xtts-error
fix(xtts): clearer error message when file given to checkpoint_dir
This commit is contained in:
commit
e8d99aaf2b
|
@ -5,7 +5,8 @@ from trainer import Trainer, TrainerArgs
|
||||||
|
|
||||||
from TTS.config.shared_configs import BaseDatasetConfig
|
from TTS.config.shared_configs import BaseDatasetConfig
|
||||||
from TTS.tts.datasets import load_tts_samples
|
from TTS.tts.datasets import load_tts_samples
|
||||||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
|
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
|
||||||
|
from TTS.tts.models.xtts import XttsAudioConfig
|
||||||
from TTS.utils.manage import ModelManager
|
from TTS.utils.manage import ModelManager
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ from TTS.tts.layers.xtts.dvae import DiscreteVAE
|
||||||
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
|
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
|
||||||
from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
|
from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
|
||||||
from TTS.tts.models.base_tts import BaseTTS
|
from TTS.tts.models.base_tts import BaseTTS
|
||||||
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
|
from TTS.tts.models.xtts import Xtts, XttsArgs
|
||||||
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
|
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -34,11 +34,6 @@ class GPTTrainerConfig(XttsConfig):
|
||||||
test_sentences: List[dict] = field(default_factory=lambda: [])
|
test_sentences: List[dict] = field(default_factory=lambda: [])
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class XttsAudioConfig(XttsAudioConfig):
|
|
||||||
dvae_sample_rate: int = 22050
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GPTArgs(XttsArgs):
|
class GPTArgs(XttsArgs):
|
||||||
min_conditioning_length: int = 66150
|
min_conditioning_length: int = 66150
|
||||||
|
|
|
@ -2,6 +2,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
import torch
|
import torch
|
||||||
|
@ -101,10 +102,12 @@ class XttsAudioConfig(Coqpit):
|
||||||
Args:
|
Args:
|
||||||
sample_rate (int): The sample rate in which the GPT operates.
|
sample_rate (int): The sample rate in which the GPT operates.
|
||||||
output_sample_rate (int): The sample rate of the output audio waveform.
|
output_sample_rate (int): The sample rate of the output audio waveform.
|
||||||
|
dvae_sample_rate (int): The sample rate of the DVAE.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sample_rate: int = 22050
|
sample_rate: int = 22050
|
||||||
output_sample_rate: int = 24000
|
output_sample_rate: int = 24000
|
||||||
|
dvae_sample_rate: int = 22050
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -719,14 +722,14 @@ class Xtts(BaseTTS):
|
||||||
|
|
||||||
def load_checkpoint(
|
def load_checkpoint(
|
||||||
self,
|
self,
|
||||||
config,
|
config: "XttsConfig",
|
||||||
checkpoint_dir=None,
|
checkpoint_dir: Optional[str] = None,
|
||||||
checkpoint_path=None,
|
checkpoint_path: Optional[str] = None,
|
||||||
vocab_path=None,
|
vocab_path: Optional[str] = None,
|
||||||
eval=True,
|
eval: bool = True,
|
||||||
strict=True,
|
strict: bool = True,
|
||||||
use_deepspeed=False,
|
use_deepspeed: bool = False,
|
||||||
speaker_file_path=None,
|
speaker_file_path: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Loads a checkpoint from disk and initializes the model's state and tokenizer.
|
Loads a checkpoint from disk and initializes the model's state and tokenizer.
|
||||||
|
@ -742,7 +745,9 @@ class Xtts(BaseTTS):
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
|
if checkpoint_dir is not None and Path(checkpoint_dir).is_file():
|
||||||
|
msg = f"You passed a file to `checkpoint_dir=`. Use `checkpoint_path={checkpoint_dir}` instead."
|
||||||
|
raise ValueError(msg)
|
||||||
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
|
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
|
||||||
if vocab_path is None:
|
if vocab_path is None:
|
||||||
if checkpoint_dir is not None and (Path(checkpoint_dir) / "vocab.json").is_file():
|
if checkpoint_dir is not None and (Path(checkpoint_dir) / "vocab.json").is_file():
|
||||||
|
|
|
@ -4,7 +4,8 @@ from trainer import Trainer, TrainerArgs
|
||||||
|
|
||||||
from TTS.config.shared_configs import BaseDatasetConfig
|
from TTS.config.shared_configs import BaseDatasetConfig
|
||||||
from TTS.tts.datasets import load_tts_samples
|
from TTS.tts.datasets import load_tts_samples
|
||||||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
|
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
|
||||||
|
from TTS.tts.models.xtts import XttsAudioConfig
|
||||||
from TTS.utils.manage import ModelManager
|
from TTS.utils.manage import ModelManager
|
||||||
|
|
||||||
# Logging parameters
|
# Logging parameters
|
||||||
|
|
|
@ -4,7 +4,8 @@ from trainer import Trainer, TrainerArgs
|
||||||
|
|
||||||
from TTS.config.shared_configs import BaseDatasetConfig
|
from TTS.config.shared_configs import BaseDatasetConfig
|
||||||
from TTS.tts.datasets import load_tts_samples
|
from TTS.tts.datasets import load_tts_samples
|
||||||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
|
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
|
||||||
|
from TTS.tts.models.xtts import XttsAudioConfig
|
||||||
from TTS.utils.manage import ModelManager
|
from TTS.utils.manage import ModelManager
|
||||||
|
|
||||||
# Logging parameters
|
# Logging parameters
|
||||||
|
|
|
@ -8,7 +8,8 @@ from tests import get_tests_output_path
|
||||||
from TTS.config.shared_configs import BaseDatasetConfig
|
from TTS.config.shared_configs import BaseDatasetConfig
|
||||||
from TTS.tts.datasets import load_tts_samples
|
from TTS.tts.datasets import load_tts_samples
|
||||||
from TTS.tts.layers.xtts.dvae import DiscreteVAE
|
from TTS.tts.layers.xtts.dvae import DiscreteVAE
|
||||||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
|
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
|
||||||
|
from TTS.tts.models.xtts import XttsAudioConfig
|
||||||
|
|
||||||
config_dataset = BaseDatasetConfig(
|
config_dataset = BaseDatasetConfig(
|
||||||
formatter="ljspeech",
|
formatter="ljspeech",
|
||||||
|
|
|
@ -8,7 +8,8 @@ from tests import get_tests_output_path
|
||||||
from TTS.config.shared_configs import BaseDatasetConfig
|
from TTS.config.shared_configs import BaseDatasetConfig
|
||||||
from TTS.tts.datasets import load_tts_samples
|
from TTS.tts.datasets import load_tts_samples
|
||||||
from TTS.tts.layers.xtts.dvae import DiscreteVAE
|
from TTS.tts.layers.xtts.dvae import DiscreteVAE
|
||||||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
|
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
|
||||||
|
from TTS.tts.models.xtts import XttsAudioConfig
|
||||||
|
|
||||||
config_dataset = BaseDatasetConfig(
|
config_dataset = BaseDatasetConfig(
|
||||||
formatter="ljspeech",
|
formatter="ljspeech",
|
||||||
|
|
Loading…
Reference in New Issue