mirror of https://github.com/coqui-ai/TTS.git
Fix Tortoise load (#2791)
* Remove key prunning in tortoise * Make lint
This commit is contained in:
parent
b3472a739e
commit
8aacb81849
|
@ -1,6 +1,5 @@
|
|||
import os
|
||||
import random
|
||||
import re
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from time import time
|
||||
|
@ -876,16 +875,12 @@ class Tortoise(BaseTTS):
|
|||
vocoder_checkpoint_path = vocoder_checkpoint_path or os.path.join(checkpoint_dir, "vocoder.pth")
|
||||
|
||||
if os.path.exists(ar_path):
|
||||
keys_to_ignore = self.autoregressive.gpt._keys_to_ignore_on_load_missing # pylint: disable=protected-access
|
||||
# remove keys from the checkpoint that are not in the model
|
||||
checkpoint = torch.load(ar_path, map_location=torch.device("cpu"))
|
||||
for key in list(checkpoint.keys()):
|
||||
for pat in keys_to_ignore:
|
||||
if re.search(pat, key) is not None:
|
||||
del checkpoint[key]
|
||||
break
|
||||
|
||||
self.autoregressive.load_state_dict(checkpoint, strict=strict)
|
||||
# strict set False
|
||||
# due to removed `bias` and `masked_bias` changes in Transformers
|
||||
self.autoregressive.load_state_dict(checkpoint, strict=False)
|
||||
|
||||
if os.path.exists(diff_path):
|
||||
self.diffusion.load_state_dict(torch.load(diff_path), strict=strict)
|
||||
|
|
|
@ -6,8 +6,8 @@ import unicodedata
|
|||
|
||||
try:
|
||||
import MeCab
|
||||
except ImportError:
|
||||
raise ImportError("Japanese requires mecab-python3 and unidic-lite.")
|
||||
except ImportError as e:
|
||||
raise ImportError("Japanese requires mecab-python3 and unidic-lite.") from e
|
||||
from num2words import num2words
|
||||
|
||||
_CONVRULES = [
|
||||
|
|
|
@ -12,4 +12,4 @@ def test_synthesize():
|
|||
'tts --model_name "coqui_studio/en/Torcull Diarmuid/coqui_studio" '
|
||||
'--text "This is it" '
|
||||
f'--out_path "{output_path}"'
|
||||
)
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue