mirror of https://github.com/coqui-ai/TTS.git
xtts: allow loading custom autoregressive model
This commit is contained in:
parent
99635193f5
commit
403b3047cb
|
@ -767,6 +767,7 @@ class Xtts(BaseTTS):
|
||||||
eval=True,
|
eval=True,
|
||||||
strict=True,
|
strict=True,
|
||||||
use_deepspeed=False,
|
use_deepspeed=False,
|
||||||
|
autoregressive_model_path=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Loads a checkpoint from disk and initializes the model's state and tokenizer.
|
Loads a checkpoint from disk and initializes the model's state and tokenizer.
|
||||||
|
@ -798,6 +799,9 @@ class Xtts(BaseTTS):
|
||||||
for key in list(checkpoint.keys()):
|
for key in list(checkpoint.keys()):
|
||||||
if key.split(".")[0] in ignore_keys:
|
if key.split(".")[0] in ignore_keys:
|
||||||
del checkpoint[key]
|
del checkpoint[key]
|
||||||
|
if autoregressive_model_path is not None:
|
||||||
|
checkpoint["gpt"] = torch.load(autoregressive_model_path, map_location=torch.device("cpu"))
|
||||||
|
|
||||||
self.load_state_dict(checkpoint, strict=strict)
|
self.load_state_dict(checkpoint, strict=strict)
|
||||||
|
|
||||||
if eval:
|
if eval:
|
||||||
|
|
Loading…
Reference in New Issue