diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 2b480744..9b534d53 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -767,6 +767,7 @@ class Xtts(BaseTTS): eval=True, strict=True, use_deepspeed=False, + autoregressive_model_path=None, ): """ Loads a checkpoint from disk and initializes the model's state and tokenizer. @@ -798,6 +799,9 @@ class Xtts(BaseTTS): for key in list(checkpoint.keys()): if key.split(".")[0] in ignore_keys: del checkpoint[key] + if autoregressive_model_path is not None: + checkpoint["gpt"] = torch.load(autoregressive_model_path, map_location=torch.device("cpu")) + self.load_state_dict(checkpoint, strict=strict) if eval: