Fix Tortoise load (#2791)

* Remove key prunning in tortoise

* Make lint
This commit is contained in:
Eren Gölge 2023-07-24 13:42:47 +02:00 committed by GitHub
parent b3472a739e
commit 8aacb81849
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 6 additions and 11 deletions

View File

@ -1,6 +1,5 @@
import os import os
import random import random
import re
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from time import time from time import time
@ -876,16 +875,12 @@ class Tortoise(BaseTTS):
vocoder_checkpoint_path = vocoder_checkpoint_path or os.path.join(checkpoint_dir, "vocoder.pth") vocoder_checkpoint_path = vocoder_checkpoint_path or os.path.join(checkpoint_dir, "vocoder.pth")
if os.path.exists(ar_path): if os.path.exists(ar_path):
keys_to_ignore = self.autoregressive.gpt._keys_to_ignore_on_load_missing # pylint: disable=protected-access
# remove keys from the checkpoint that are not in the model # remove keys from the checkpoint that are not in the model
checkpoint = torch.load(ar_path, map_location=torch.device("cpu")) checkpoint = torch.load(ar_path, map_location=torch.device("cpu"))
for key in list(checkpoint.keys()):
for pat in keys_to_ignore:
if re.search(pat, key) is not None:
del checkpoint[key]
break
self.autoregressive.load_state_dict(checkpoint, strict=strict) # strict set False
# due to removed `bias` and `masked_bias` changes in Transformers
self.autoregressive.load_state_dict(checkpoint, strict=False)
if os.path.exists(diff_path): if os.path.exists(diff_path):
self.diffusion.load_state_dict(torch.load(diff_path), strict=strict) self.diffusion.load_state_dict(torch.load(diff_path), strict=strict)

View File

@ -6,8 +6,8 @@ import unicodedata
try: try:
import MeCab import MeCab
except ImportError: except ImportError as e:
raise ImportError("Japanese requires mecab-python3 and unidic-lite.") raise ImportError("Japanese requires mecab-python3 and unidic-lite.") from e
from num2words import num2words from num2words import num2words
_CONVRULES = [ _CONVRULES = [