From 8aacb81849a18ce934b8501c7e3a9f76b22dbd0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 24 Jul 2023 13:42:47 +0200 Subject: [PATCH] Fix Tortoise load (#2791) * Remove key prunning in tortoise * Make lint --- TTS/tts/models/tortoise.py | 11 +++-------- TTS/tts/utils/text/japanese/phonemizer.py | 4 ++-- tests/api_tests/test_synthesize_api.py | 2 +- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/TTS/tts/models/tortoise.py b/TTS/tts/models/tortoise.py index d9988256..2b140e56 100644 --- a/TTS/tts/models/tortoise.py +++ b/TTS/tts/models/tortoise.py @@ -1,6 +1,5 @@ import os import random -import re from contextlib import contextmanager from dataclasses import dataclass from time import time @@ -876,16 +875,12 @@ class Tortoise(BaseTTS): vocoder_checkpoint_path = vocoder_checkpoint_path or os.path.join(checkpoint_dir, "vocoder.pth") if os.path.exists(ar_path): - keys_to_ignore = self.autoregressive.gpt._keys_to_ignore_on_load_missing # pylint: disable=protected-access # remove keys from the checkpoint that are not in the model checkpoint = torch.load(ar_path, map_location=torch.device("cpu")) - for key in list(checkpoint.keys()): - for pat in keys_to_ignore: - if re.search(pat, key) is not None: - del checkpoint[key] - break - self.autoregressive.load_state_dict(checkpoint, strict=strict) + # strict set False + # due to removed `bias` and `masked_bias` changes in Transformers + self.autoregressive.load_state_dict(checkpoint, strict=False) if os.path.exists(diff_path): self.diffusion.load_state_dict(torch.load(diff_path), strict=strict) diff --git a/TTS/tts/utils/text/japanese/phonemizer.py b/TTS/tts/utils/text/japanese/phonemizer.py index 7f915388..c3111067 100644 --- a/TTS/tts/utils/text/japanese/phonemizer.py +++ b/TTS/tts/utils/text/japanese/phonemizer.py @@ -6,8 +6,8 @@ import unicodedata try: import MeCab -except ImportError: - raise ImportError("Japanese requires mecab-python3 and unidic-lite.") +except ImportError as e: + raise ImportError("Japanese requires mecab-python3 and unidic-lite.") from e from num2words import num2words _CONVRULES = [ diff --git a/tests/api_tests/test_synthesize_api.py b/tests/api_tests/test_synthesize_api.py index 6e1f013c..a96c8bea 100644 --- a/tests/api_tests/test_synthesize_api.py +++ b/tests/api_tests/test_synthesize_api.py @@ -12,4 +12,4 @@ def test_synthesize(): 'tts --model_name "coqui_studio/en/Torcull Diarmuid/coqui_studio" ' '--text "This is it" ' f'--out_path "{output_path}"' - ) \ No newline at end of file + )