mirror of https://github.com/coqui-ai/TTS.git
fix(xtts): clearer error message when file given to checkpoint_dir
This commit is contained in:
parent
98a372bca2
commit
ce202532cf
|
@ -2,6 +2,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
import torch
|
import torch
|
||||||
|
@ -10,6 +11,7 @@ import torchaudio
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
from trainer.io import load_fsspec
|
from trainer.io import load_fsspec
|
||||||
|
|
||||||
|
from TTS.tts.configs.xtts_config import XttsConfig
|
||||||
from TTS.tts.layers.xtts.gpt import GPT
|
from TTS.tts.layers.xtts.gpt import GPT
|
||||||
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
|
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
|
||||||
from TTS.tts.layers.xtts.stream_generator import init_stream_support
|
from TTS.tts.layers.xtts.stream_generator import init_stream_support
|
||||||
|
@ -719,14 +721,14 @@ class Xtts(BaseTTS):
|
||||||
|
|
||||||
def load_checkpoint(
|
def load_checkpoint(
|
||||||
self,
|
self,
|
||||||
config,
|
config: XttsConfig,
|
||||||
checkpoint_dir=None,
|
checkpoint_dir: Optional[str] = None,
|
||||||
checkpoint_path=None,
|
checkpoint_path: Optional[str] = None,
|
||||||
vocab_path=None,
|
vocab_path: Optional[str] = None,
|
||||||
eval=True,
|
eval: bool = True,
|
||||||
strict=True,
|
strict: bool = True,
|
||||||
use_deepspeed=False,
|
use_deepspeed: bool = False,
|
||||||
speaker_file_path=None,
|
speaker_file_path: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Loads a checkpoint from disk and initializes the model's state and tokenizer.
|
Loads a checkpoint from disk and initializes the model's state and tokenizer.
|
||||||
|
@ -742,7 +744,9 @@ class Xtts(BaseTTS):
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
|
if checkpoint_dir is not None and Path(checkpoint_dir).is_file():
|
||||||
|
msg = f"You passed a file to `checkpoint_dir=`. Use `checkpoint_path={checkpoint_dir}` instead."
|
||||||
|
raise ValueError(msg)
|
||||||
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
|
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
|
||||||
if vocab_path is None:
|
if vocab_path is None:
|
||||||
if checkpoint_dir is not None and (Path(checkpoint_dir) / "vocab.json").is_file():
|
if checkpoint_dir is not None and (Path(checkpoint_dir) / "vocab.json").is_file():
|
||||||
|
|
Loading…
Reference in New Issue