diff --git a/TTS/.models.json b/TTS/.models.json index 8e35893b..0c318740 100644 --- a/TTS/.models.json +++ b/TTS/.models.json @@ -18,12 +18,12 @@ "xtts_v1.1": { "description": "XTTS-v1.1 by Coqui with 14 languages, cross-language voice cloning and reference leak fixed.", "hf_url": [ - "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/model.pth", - "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/config.json", - "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/vocab.json", - "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/hash.md5" + "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/model.pth", + "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/config.json", + "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/vocab.json", + "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/hash.md5" ], - "model_hash": "10163afc541dc86801b33d1f3217b456", + "model_hash": "ae9e4b39e095fd5728fe7f7931ec66ad", "default_vocoder": null, "commit": "82910a63", "license": "CPML", diff --git a/TTS/tts/configs/xtts_config.py b/TTS/tts/configs/xtts_config.py index b9685590..4e5031ba 100644 --- a/TTS/tts/configs/xtts_config.py +++ b/TTS/tts/configs/xtts_config.py @@ -78,13 +78,13 @@ class XttsConfig(BaseTTSConfig): ) # inference params - temperature: float = 0.2 + temperature: float = 0.85 length_penalty: float = 1.0 repetition_penalty: float = 2.0 top_k: int = 50 - top_p: float = 0.8 + top_p: float = 0.85 cond_free_k: float = 2.0 diffusion_temperature: float = 1.0 - num_gpt_outputs: int = 16 + num_gpt_outputs: int = 1 decoder_iterations: int = 30 decoder_sampler: str = "ddim" diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 76c5595e..40e8f946 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -821,8 +821,6 @@ class Xtts(BaseTTS): self.tokenizer = VoiceBpeTokenizer(vocab_file=vocab_path) self.init_models() - if eval: - self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache) checkpoint = load_fsspec(model_path, 
map_location=torch.device("cpu"))["model"] ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else [] @@ -831,7 +829,14 @@ class Xtts(BaseTTS): for key in list(checkpoint.keys()): if key.split(".")[0] in ignore_keys: del checkpoint[key] - self.load_state_dict(checkpoint, strict=strict) + + # deal with v1 and v1.1: v1 has the init_gpt_for_inference keys, v1.1 does not + try: + self.load_state_dict(checkpoint, strict=strict) + except: + if eval: + self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache) + self.load_state_dict(checkpoint, strict=strict) if eval: if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval()