From ce202532cfe74e2e297e4109a80a3b125f54bd49 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Mon, 2 Dec 2024 16:54:11 +0100 Subject: [PATCH] fix(xtts): clearer error message when file given to checkpoint_dir --- TTS/tts/models/xtts.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 35de91e3..d780e2b3 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -2,6 +2,7 @@ import logging import os from dataclasses import dataclass from pathlib import Path +from typing import Optional import librosa import torch @@ -10,6 +11,7 @@ import torchaudio from coqpit import Coqpit from trainer.io import load_fsspec +from TTS.tts.configs.xtts_config import XttsConfig from TTS.tts.layers.xtts.gpt import GPT from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder from TTS.tts.layers.xtts.stream_generator import init_stream_support @@ -719,14 +721,14 @@ class Xtts(BaseTTS): def load_checkpoint( self, - config, - checkpoint_dir=None, - checkpoint_path=None, - vocab_path=None, - eval=True, - strict=True, - use_deepspeed=False, - speaker_file_path=None, + config: XttsConfig, + checkpoint_dir: Optional[str] = None, + checkpoint_path: Optional[str] = None, + vocab_path: Optional[str] = None, + eval: bool = True, + strict: bool = True, + use_deepspeed: bool = False, + speaker_file_path: Optional[str] = None, ): """ Loads a checkpoint from disk and initializes the model's state and tokenizer. @@ -742,7 +744,9 @@ class Xtts(BaseTTS): Returns: None """ - + if checkpoint_dir is not None and Path(checkpoint_dir).is_file(): + msg = f"You passed a file to `checkpoint_dir=`. Use `checkpoint_path={checkpoint_dir}` instead." + raise ValueError(msg) model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth") if vocab_path is None: if checkpoint_dir is not None and (Path(checkpoint_dir) / "vocab.json").is_file():