From 1432572b40ac4cd585447c23b4d8a832717bb1da Mon Sep 17 00:00:00 2001 From: bivashy Date: Sat, 30 Dec 2023 00:14:48 +0600 Subject: [PATCH] Swap checkpoint_dir, checkpoint_path parameters in xtts --- TTS/tts/models/xtts.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 8e9d6bd3..6b4cdb8a 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -731,8 +731,8 @@ class Xtts(BaseTTS): def load_checkpoint( self, config, - checkpoint_dir=None, checkpoint_path=None, + checkpoint_dir=None, vocab_path=None, eval=True, strict=True, @@ -744,8 +744,8 @@ class Xtts(BaseTTS): Args: config (dict): The configuration dictionary for the model. - checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None. checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None. + checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None. vocab_path (str, optional): The path to the vocabulary file. Defaults to None. eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True. strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True. @@ -753,7 +753,8 @@ class Xtts(BaseTTS): Returns: None """ - + if checkpoint_dir is None and checkpoint_path: + checkpoint_dir = os.path.dirname(checkpoint_path) model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth") vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")