Refactoring VITS for the tokenizer API

This commit is contained in:
Eren Gölge 2021-11-30 15:55:36 +01:00
parent 4cd690e4c1
commit 7575367b9f
1 changed file with 26 additions and 7 deletions

View File

@ -19,6 +19,7 @@ from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, se
from TTS.tts.utils.languages import LanguageManager from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment from TTS.tts.utils.visual import plot_alignment
from TTS.utils.trainer_utils import get_optimizer, get_scheduler from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.models.hifigan_generator import HifiganGenerator
@ -280,19 +281,15 @@ class Vits(BaseTTS):
language_manager: LanguageManager = None, language_manager: LanguageManager = None,
): ):
super().__init__(config) super().__init__(config, ap, tokenizer, speaker_manager)
self.END2END = True self.END2END = True
self.speaker_manager = speaker_manager self.speaker_manager = speaker_manager
self.language_manager = language_manager self.language_manager = language_manager
if config.__class__.__name__ == "VitsConfig": if config.__class__.__name__ == "VitsConfig":
# loading from VitsConfig # loading from VitsConfig
if "num_chars" not in config: self.num_chars = self.tokenizer.characters.num_chars
_, self.config, num_chars = self.get_characters(config) self.config = config
config.model_args.num_chars = num_chars
else:
self.config = config
config.model_args.num_chars = config.num_chars
args = self.config.model_args args = self.config.model_args
elif isinstance(config, VitsArgs): elif isinstance(config, VitsArgs):
# loading from VitsArgs # loading from VitsArgs
@ -1039,3 +1036,25 @@ class Vits(BaseTTS):
if eval: if eval:
self.eval() self.eval()
assert not self.training assert not self.training
@staticmethod
def init_from_config(config: "Coqpit"):
    """Build a ``Vits`` model and its helper objects from a config.

    The character set class is picked by ``config.use_phonemes`` and the
    resulting character count is written back to ``config.num_chars``, so
    ``config`` is modified in place before the other components are created.

    Args:
        config (Coqpit): Model configuration; ``use_phonemes`` is read from it.

    Returns:
        Vits: Model built from ``config``, an audio processor, a tokenizer
            and a speaker manager, each initialized from the same config.
    """
    # Import only the character class that the config actually asks for.
    if config.use_phonemes:
        from TTS.tts.utils.text.characters import IPAPhonemes as charset_cls
    else:
        from TTS.tts.utils.text.characters import Graphemes as charset_cls

    # Record the vocabulary size on the config before building the helpers.
    config.num_chars = charset_cls().init_from_config(config).num_chars

    from TTS.utils.audio import AudioProcessor

    # Built in this order: audio processor, tokenizer, speaker manager.
    components = (
        AudioProcessor.init_from_config(config),
        TTSTokenizer.init_from_config(config),
        SpeakerManager.init_from_config(config),
    )
    return Vits(config, *components)