mirror of https://github.com/coqui-ai/TTS.git
Fix Tortoise load (#2697)
* Handle missing gpt weights * Make style * Fix lint
This commit is contained in:
parent
d65819422b
commit
4cf8652392
|
@ -1,5 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
import re
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from time import time
|
from time import time
|
||||||
|
@ -871,7 +872,16 @@ class Tortoise(BaseTTS):
|
||||||
vocoder_checkpoint_path = vocoder_checkpoint_path or os.path.join(checkpoint_dir, "vocoder.pth")
|
vocoder_checkpoint_path = vocoder_checkpoint_path or os.path.join(checkpoint_dir, "vocoder.pth")
|
||||||
|
|
||||||
if os.path.exists(ar_path):
|
if os.path.exists(ar_path):
|
||||||
self.autoregressive.load_state_dict(torch.load(ar_path), strict=strict)
|
keys_to_ignore = self.autoregressive.gpt._keys_to_ignore_on_load_missing # pylint: disable=protected-access
|
||||||
|
# remove keys from the checkpoint that are not in the model
|
||||||
|
checkpoint = torch.load(ar_path, map_location=torch.device("cpu"))
|
||||||
|
for key in list(checkpoint.keys()):
|
||||||
|
for pat in keys_to_ignore:
|
||||||
|
if re.search(pat, key) is not None:
|
||||||
|
del checkpoint[key]
|
||||||
|
break
|
||||||
|
|
||||||
|
self.autoregressive.load_state_dict(checkpoint, strict=strict)
|
||||||
|
|
||||||
if os.path.exists(diff_path):
|
if os.path.exists(diff_path):
|
||||||
self.diffusion.load_state_dict(torch.load(diff_path), strict=strict)
|
self.diffusion.load_state_dict(torch.load(diff_path), strict=strict)
|
||||||
|
|
Loading…
Reference in New Issue