mirror of https://github.com/coqui-ai/TTS.git
Swap checkpoint_dir, checkpoint_path parameters in xtts
This commit is contained in:
parent
c1762660ed
commit
1432572b40
|
@ -731,8 +731,8 @@ class Xtts(BaseTTS):
|
|||
def load_checkpoint(
|
||||
self,
|
||||
config,
|
||||
checkpoint_dir=None,
|
||||
checkpoint_path=None,
|
||||
checkpoint_dir=None,
|
||||
vocab_path=None,
|
||||
eval=True,
|
||||
strict=True,
|
||||
|
@ -744,8 +744,8 @@ class Xtts(BaseTTS):
|
|||
|
||||
Args:
|
||||
config (dict): The configuration dictionary for the model.
|
||||
checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
|
||||
checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None.
|
||||
checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
|
||||
vocab_path (str, optional): The path to the vocabulary file. Defaults to None.
|
||||
eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True.
|
||||
strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True.
|
||||
|
@ -753,7 +753,8 @@ class Xtts(BaseTTS):
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
if checkpoint_dir is None and checkpoint_path:
|
||||
checkpoint_dir = os.path.dirname(checkpoint_path)
|
||||
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
|
||||
vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
|
||||
|
||||
|
|
Loading…
Reference in New Issue