diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py index 1eb66309..f34a7ac0 100644 --- a/TTS/tts/layers/xtts/tokenizer.py +++ b/TTS/tts/layers/xtts/tokenizer.py @@ -509,3 +509,6 @@ class VoiceBpeTokenizer: def __len__(self): return self.tokenizer.get_vocab_size() + + def get_number_tokens(self): + return max(self.tokenizer.get_vocab().values()) + 1 diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 784ba1be..76c5595e 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -312,7 +312,7 @@ class Xtts(BaseTTS): def init_models(self): """Initialize the models. We do it here since we need to load the tokenizer first.""" if self.tokenizer.tokenizer is not None: - self.args.gpt_number_text_tokens = max(self.tokenizer.tokenizer.get_vocab().values()) + 1 + self.args.gpt_number_text_tokens = self.tokenizer.get_number_tokens() self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]") self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]")