xtts: allow loading custom autoregressive model

This commit is contained in:
Emmanuel Schmidbauer 2023-10-12 12:57:19 -04:00
parent 99635193f5
commit 403b3047cb
1 changed files with 4 additions and 0 deletions

View File

@ -767,6 +767,7 @@ class Xtts(BaseTTS):
eval=True,
strict=True,
use_deepspeed=False,
autoregressive_model_path=None,
):
"""
Loads a checkpoint from disk and initializes the model's state and tokenizer.
@ -798,6 +799,9 @@ class Xtts(BaseTTS):
for key in list(checkpoint.keys()):
if key.split(".")[0] in ignore_keys:
del checkpoint[key]
if autoregressive_model_path is not None:
checkpoint["gpt"] = torch.load(autoregressive_model_path, map_location=torch.device("cpu"))
self.load_state_dict(checkpoint, strict=strict)
if eval: