xtts: allow loading custom autoregressive model

This commit is contained in:
Emmanuel Schmidbauer 2023-10-12 12:57:19 -04:00
parent 99635193f5
commit 403b3047cb
1 changed files with 4 additions and 0 deletions

View File

@ -767,6 +767,7 @@ class Xtts(BaseTTS):
eval=True, eval=True,
strict=True, strict=True,
use_deepspeed=False, use_deepspeed=False,
autoregressive_model_path=None,
): ):
""" """
Loads a checkpoint from disk and initializes the model's state and tokenizer. Loads a checkpoint from disk and initializes the model's state and tokenizer.
@ -798,6 +799,9 @@ class Xtts(BaseTTS):
for key in list(checkpoint.keys()): for key in list(checkpoint.keys()):
if key.split(".")[0] in ignore_keys: if key.split(".")[0] in ignore_keys:
del checkpoint[key] del checkpoint[key]
if autoregressive_model_path is not None:
checkpoint["gpt"] = torch.load(autoregressive_model_path, map_location=torch.device("cpu"))
self.load_state_dict(checkpoint, strict=strict) self.load_state_dict(checkpoint, strict=strict)
if eval: if eval: