mirror of https://github.com/coqui-ai/TTS.git
Swap checkpoint_dir, checkpoint_path parameters in xtts
This commit is contained in:
parent
c1762660ed
commit
1432572b40
|
@ -731,8 +731,8 @@ class Xtts(BaseTTS):
|
||||||
def load_checkpoint(
|
def load_checkpoint(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
checkpoint_dir=None,
|
|
||||||
checkpoint_path=None,
|
checkpoint_path=None,
|
||||||
|
checkpoint_dir=None,
|
||||||
vocab_path=None,
|
vocab_path=None,
|
||||||
eval=True,
|
eval=True,
|
||||||
strict=True,
|
strict=True,
|
||||||
|
@ -744,8 +744,8 @@ class Xtts(BaseTTS):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config (dict): The configuration dictionary for the model.
|
config (dict): The configuration dictionary for the model.
|
||||||
checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
|
|
||||||
checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None.
|
checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None.
|
||||||
|
checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
|
||||||
vocab_path (str, optional): The path to the vocabulary file. Defaults to None.
|
vocab_path (str, optional): The path to the vocabulary file. Defaults to None.
|
||||||
eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True.
|
eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True.
|
||||||
strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True.
|
strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True.
|
||||||
|
@ -753,7 +753,8 @@ class Xtts(BaseTTS):
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
|
if checkpoint_dir is None and checkpoint_path:
|
||||||
|
checkpoint_dir = os.path.dirname(checkpoint_path)
|
||||||
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
|
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
|
||||||
vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
|
vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue