cached_property for cutlet

This commit is contained in:
WeberJulian 2023-11-07 15:12:52 +01:00
parent 9dd31038f6
commit 8487e37376
1 changed files with 7 additions and 5 deletions

View File

@ -8,6 +8,7 @@ from hangul_romanize import Transliter
from hangul_romanize.rule import academic from hangul_romanize.rule import academic
from num2words import num2words from num2words import num2words
from tokenizers import Tokenizer from tokenizers import Tokenizer
from yarl import cached_property
from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
@ -535,7 +536,6 @@ DEFAULT_VOCAB_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "
class VoiceBpeTokenizer: class VoiceBpeTokenizer:
def __init__(self, vocab_file=None): def __init__(self, vocab_file=None):
self.tokenizer = None self.tokenizer = None
self.katsu = None
if vocab_file is not None: if vocab_file is not None:
self.tokenizer = Tokenizer.from_file(vocab_file) self.tokenizer = Tokenizer.from_file(vocab_file)
self.char_limits = { self.char_limits = {
@ -557,6 +557,11 @@ class VoiceBpeTokenizer:
"ko": 95, "ko": 95,
} }
@cached_property
def katsu(self):
import cutlet
return cutlet.Cutlet()
def check_input_length(self, txt, lang): def check_input_length(self, txt, lang):
limit = self.char_limits.get(lang, 250) limit = self.char_limits.get(lang, 250)
if len(txt) > limit: if len(txt) > limit:
@ -568,9 +573,6 @@ class VoiceBpeTokenizer:
if lang == "zh-cn": if lang == "zh-cn":
txt = chinese_transliterate(txt) txt = chinese_transliterate(txt)
elif lang == "ja": elif lang == "ja":
if self.katsu is None:
import cutlet
self.katsu = cutlet.Cutlet()
txt = japanese_cleaners(txt, self.katsu) txt = japanese_cleaners(txt, self.katsu)
else: else:
raise NotImplementedError() raise NotImplementedError()