Fix Tortoise load (#2697)

* Handle missing gpt weights

* Make style

* Fix lint
This commit is contained in:
Eren Gölge 2023-06-21 15:42:01 +02:00 committed by GitHub
parent d65819422b
commit 4cf8652392
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 11 additions and 1 deletions

View File

@ -1,5 +1,6 @@
import os import os
import random import random
import re
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from time import time from time import time
@ -871,7 +872,16 @@ class Tortoise(BaseTTS):
vocoder_checkpoint_path = vocoder_checkpoint_path or os.path.join(checkpoint_dir, "vocoder.pth") vocoder_checkpoint_path = vocoder_checkpoint_path or os.path.join(checkpoint_dir, "vocoder.pth")
if os.path.exists(ar_path): if os.path.exists(ar_path):
self.autoregressive.load_state_dict(torch.load(ar_path), strict=strict) keys_to_ignore = self.autoregressive.gpt._keys_to_ignore_on_load_missing # pylint: disable=protected-access
# remove keys from the checkpoint that are not in the model
checkpoint = torch.load(ar_path, map_location=torch.device("cpu"))
for key in list(checkpoint.keys()):
for pat in keys_to_ignore:
if re.search(pat, key) is not None:
del checkpoint[key]
break
self.autoregressive.load_state_dict(checkpoint, strict=strict)
if os.path.exists(diff_path): if os.path.exists(diff_path):
self.diffusion.load_state_dict(torch.load(diff_path), strict=strict) self.diffusion.load_state_dict(torch.load(diff_path), strict=strict)