Swap checkpoint_dir, checkpoint_path parameters in xtts

This commit is contained in:
bivashy 2023-12-30 00:14:48 +06:00
parent c1762660ed
commit 1432572b40
No known key found for this signature in database
GPG Key ID: AF03FBF1D67CAA32
1 changed files with 4 additions and 3 deletions

View File

@ -731,8 +731,8 @@ class Xtts(BaseTTS):
def load_checkpoint(
self,
config,
checkpoint_dir=None,
checkpoint_path=None,
checkpoint_dir=None,
vocab_path=None,
eval=True,
strict=True,
@ -744,8 +744,8 @@ class Xtts(BaseTTS):
Args:
config (dict): The configuration dictionary for the model.
checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None.
checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
vocab_path (str, optional): The path to the vocabulary file. Defaults to None.
eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True.
strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True.
@ -753,7 +753,8 @@ class Xtts(BaseTTS):
Returns:
None
"""
if checkpoint_dir is None and checkpoint_path:
checkpoint_dir = os.path.dirname(checkpoint_path)
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")