Fix GlowTTS

This commit is contained in:
Eren Gölge 2021-11-24 18:42:44 +01:00
parent 87bf940676
commit 73d27ebd45
1 changed files with 20 additions and 0 deletions

View File

@ -14,6 +14,7 @@ from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.io import load_fsspec
@ -513,3 +514,22 @@ class GlowTTS(BaseTTS):
def on_train_step_start(self, trainer):
"""Decide on every training step wheter enable/disable data depended initialization."""
self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps
@staticmethod
def init_from_config(config: Coqpit):
"""Initialize model from config."""
# init characters
if config.use_phonemes:
from TTS.tts.utils.text.characters import IPAPhonemes
characters = IPAPhonemes().init_from_config(config)
else:
from TTS.tts.utils.text.characters import Graphemes
characters = Graphemes().init_from_config(config)
config.num_chars = characters.num_chars
from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config)
tokenizer = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config)
return GlowTTS(config, ap, tokenizer, speaker_manager)