fix: xtts not taking into account device flag (#2951)

* fix: xtts not taking into account device flag

* Style changes

---------

Co-authored-by: Julian Weber <julian.weber@hotmail.fr>
This commit is contained in:
loupzeur 2023-09-20 09:57:02 +02:00 committed by GitHub
parent 335ae63e01
commit da8b6bbce1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -642,7 +642,7 @@ class Xtts(BaseTTS):
self.init_models()
if eval:
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
self.load_state_dict(load_fsspec(model_path)["model"], strict=strict)
self.load_state_dict(load_fsspec(model_path, map_location=self.device)["model"], strict=strict)
if eval:
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)