mirror of https://github.com/coqui-ai/TTS.git
Bug fix on XTTS v1.1 inference
This commit is contained in:
parent
85e7323739
commit
a430360c1b
|
@ -78,13 +78,13 @@ class XttsConfig(BaseTTSConfig):
|
|||
)
|
||||
|
||||
# inference params
|
||||
temperature: float = 0.2
|
||||
temperature: float = 0.85
|
||||
length_penalty: float = 1.0
|
||||
repetition_penalty: float = 2.0
|
||||
top_k: int = 50
|
||||
top_p: float = 0.8
|
||||
top_p: float = 0.85
|
||||
cond_free_k: float = 2.0
|
||||
diffusion_temperature: float = 1.0
|
||||
num_gpt_outputs: int = 16
|
||||
num_gpt_outputs: int = 1
|
||||
decoder_iterations: int = 30
|
||||
decoder_sampler: str = "ddim"
|
||||
|
|
|
@ -821,8 +821,6 @@ class Xtts(BaseTTS):
|
|||
self.tokenizer = VoiceBpeTokenizer(vocab_file=vocab_path)
|
||||
|
||||
self.init_models()
|
||||
if eval:
|
||||
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
|
||||
|
||||
checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
|
||||
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
|
||||
|
@ -831,7 +829,14 @@ class Xtts(BaseTTS):
|
|||
for key in list(checkpoint.keys()):
|
||||
if key.split(".")[0] in ignore_keys:
|
||||
del checkpoint[key]
|
||||
self.load_state_dict(checkpoint, strict=strict)
|
||||
|
||||
# deal with v1 and v1.1. V1 has the init_gpt_for_inference keys, v1.1 do not
|
||||
try:
|
||||
self.load_state_dict(checkpoint, strict=strict)
|
||||
except:
|
||||
if eval:
|
||||
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
|
||||
self.load_state_dict(checkpoint, strict=strict)
|
||||
|
||||
if eval:
|
||||
if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval()
|
||||
|
|
Loading…
Reference in New Issue