fix(xtts): clearer error message when file given to checkpoint_dir

This commit is contained in:
Enno Hermann 2024-12-02 16:54:11 +01:00
parent 98a372bca2
commit ce202532cf
1 changed files with 13 additions and 9 deletions

View File

@ -2,6 +2,7 @@ import logging
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Optional
import librosa import librosa
import torch import torch
@ -10,6 +11,7 @@ import torchaudio
from coqpit import Coqpit from coqpit import Coqpit
from trainer.io import load_fsspec from trainer.io import load_fsspec
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.layers.xtts.gpt import GPT from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.stream_generator import init_stream_support from TTS.tts.layers.xtts.stream_generator import init_stream_support
@ -719,14 +721,14 @@ class Xtts(BaseTTS):
def load_checkpoint( def load_checkpoint(
self, self,
config, config: XttsConfig,
checkpoint_dir=None, checkpoint_dir: Optional[str] = None,
checkpoint_path=None, checkpoint_path: Optional[str] = None,
vocab_path=None, vocab_path: Optional[str] = None,
eval=True, eval: bool = True,
strict=True, strict: bool = True,
use_deepspeed=False, use_deepspeed: bool = False,
speaker_file_path=None, speaker_file_path: Optional[str] = None,
): ):
""" """
Loads a checkpoint from disk and initializes the model's state and tokenizer. Loads a checkpoint from disk and initializes the model's state and tokenizer.
@ -742,7 +744,9 @@ class Xtts(BaseTTS):
Returns: Returns:
None None
""" """
if checkpoint_dir is not None and Path(checkpoint_dir).is_file():
msg = f"You passed a file to `checkpoint_dir=`. Use `checkpoint_path={checkpoint_dir}` instead."
raise ValueError(msg)
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth") model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
if vocab_path is None: if vocab_path is None:
if checkpoint_dir is not None and (Path(checkpoint_dir) / "vocab.json").is_file(): if checkpoint_dir is not None and (Path(checkpoint_dir) / "vocab.json").is_file():