From 403b3047cbdeab80e4aa2fb258896b33fa4bd879 Mon Sep 17 00:00:00 2001 From: Emmanuel Schmidbauer Date: Thu, 12 Oct 2023 12:57:19 -0400 Subject: [PATCH] xtts: allow loading custom autoregressive model --- TTS/tts/models/xtts.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 2b480744..9b534d53 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -767,6 +767,7 @@ class Xtts(BaseTTS): eval=True, strict=True, use_deepspeed=False, + autoregressive_model_path=None, ): """ Loads a checkpoint from disk and initializes the model's state and tokenizer. @@ -798,6 +799,9 @@ class Xtts(BaseTTS): for key in list(checkpoint.keys()): if key.split(".")[0] in ignore_keys: del checkpoint[key] + if autoregressive_model_path is not None: + checkpoint["gpt"] = torch.load(autoregressive_model_path, map_location=torch.device("cpu")) + self.load_state_dict(checkpoint, strict=strict) if eval: