Bug fix on XTTS v1.1 inference (#3093)

* Bug fix on XTTS v1.1 inference

* Update .models.json

---------

Co-authored-by: Julian Weber <julian.weber@hotmail.fr>
This commit is contained in:
Edresson Casanova 2023-10-20 17:29:43 -03:00 committed by GitHub
parent 85e7323739
commit 59576fc0ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 16 additions and 11 deletions

View File

@ -18,12 +18,12 @@
"xtts_v1.1": {
"description": "XTTS-v1.1 by Coqui with 14 languages, cross-language voice cloning and reference leak fixed.",
"hf_url": [
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/model.pth",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/config.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/vocab.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/hash.md5"
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/model.pth",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/config.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/vocab.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/hash.md5"
],
"model_hash": "10163afc541dc86801b33d1f3217b456",
"model_hash": "ae9e4b39e095fd5728fe7f7931ec66ad",
"default_vocoder": null,
"commit": "82910a63",
"license": "CPML",

View File

@ -78,13 +78,13 @@ class XttsConfig(BaseTTSConfig):
)
# inference params
temperature: float = 0.2
temperature: float = 0.85
length_penalty: float = 1.0
repetition_penalty: float = 2.0
top_k: int = 50
top_p: float = 0.8
top_p: float = 0.85
cond_free_k: float = 2.0
diffusion_temperature: float = 1.0
num_gpt_outputs: int = 16
num_gpt_outputs: int = 1
decoder_iterations: int = 30
decoder_sampler: str = "ddim"

View File

@ -821,8 +821,6 @@ class Xtts(BaseTTS):
self.tokenizer = VoiceBpeTokenizer(vocab_file=vocab_path)
self.init_models()
if eval:
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
@ -831,6 +829,13 @@ class Xtts(BaseTTS):
for key in list(checkpoint.keys()):
if key.split(".")[0] in ignore_keys:
del checkpoint[key]
# deal with v1 and v1.1. V1 has the init_gpt_for_inference keys, v1.1 do not
try:
self.load_state_dict(checkpoint, strict=strict)
except:
if eval:
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
self.load_state_dict(checkpoint, strict=strict)
if eval: