mirror of https://github.com/coqui-ai/TTS.git
xtts: allow loading custom autoregressive model
This commit is contained in:
parent
99635193f5
commit
403b3047cb
|
@ -767,6 +767,7 @@ class Xtts(BaseTTS):
|
|||
eval=True,
|
||||
strict=True,
|
||||
use_deepspeed=False,
|
||||
autoregressive_model_path=None,
|
||||
):
|
||||
"""
|
||||
Loads a checkpoint from disk and initializes the model's state and tokenizer.
|
||||
|
@ -798,6 +799,9 @@ class Xtts(BaseTTS):
|
|||
for key in list(checkpoint.keys()):
|
||||
if key.split(".")[0] in ignore_keys:
|
||||
del checkpoint[key]
|
||||
if autoregressive_model_path is not None:
|
||||
checkpoint["gpt"] = torch.load(autoregressive_model_path, map_location=torch.device("cpu"))
|
||||
|
||||
self.load_state_dict(checkpoint, strict=strict)
|
||||
|
||||
if eval:
|
||||
|
|
Loading…
Reference in New Issue