Refactor GlowTTS model and recipe for TTSTokenizer

Eren Gölge 2021-11-16 13:36:35 +01:00
parent d0eb642d88
commit 9a95e15483
2 changed files with 9 additions and 5 deletions

@@ -14,6 +14,7 @@ from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import Graphemes, make_symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
# pylint: skip-file
@@ -32,9 +33,7 @@ class BaseTTS(BaseModel):
- 1D tensors `batch x 1`
"""
-def __init__(
-    self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None
-):
+def __init__(self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None):
super().__init__(config)
self.config = config
self.ap = ap
@@ -292,7 +291,7 @@ class BaseTTS(BaseModel):
verbose=verbose,
speaker_id_mapping=speaker_id_mapping,
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
tokenizer=self.tokenizer,
tokenizer=self.tokenizer
)
# pre-compute phonemes
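
Taken together, the hunks above define the constructor contract that BaseTTS subclasses follow after this change: a model is built with (config, ap, tokenizer, speaker_manager=None), and the stored tokenizer is later handed to the dataset as tokenizer=self.tokenizer. Below is a minimal sketch of a hypothetical subclass; only the BaseTTS signature and the tokenizer=self.tokenizer pass are confirmed by the diff, while the class name and the module paths in the imports are assumptions.

# Sketch only: MyGlowLikeTTS is a hypothetical example, not a class in the repository.
from coqpit import Coqpit

from TTS.tts.models.base_tts import BaseTTS        # module path assumed from the hunk context
from TTS.tts.utils.speakers import SpeakerManager  # import shown in the first hunk


class MyGlowLikeTTS(BaseTTS):
    def __init__(
        self,
        config: Coqpit,
        ap: "AudioProcessor",
        tokenizer: "TTSTokenizer",
        speaker_manager: SpeakerManager = None,
    ):
        # BaseTTS keeps the tokenizer on the instance (it reappears as
        # tokenizer=self.tokenizer around line 292), so a subclass only
        # needs to forward its arguments.
        super().__init__(config, ap, tokenizer, speaker_manager=speaker_manager)

GlowTTS itself follows the same pattern; the recipe diff below constructs it as GlowTTS(config, ap, tokenizer, speaker_manager=None).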

@@ -71,7 +71,12 @@ model = GlowTTS(config, ap, tokenizer, speaker_manager=None)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc.
trainer = Trainer(
-    TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
+    TrainingArgs(),
+    config,
+    output_path,
+    model=model,
+    train_samples=train_samples,
+    eval_samples=eval_samples
)
# AND... 3,2,1... 🚀
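
For context, here is a condensed sketch of how the recipe tail fits together once this diff is applied, written as a hypothetical helper function. Only the GlowTTS(...) and Trainer(...) calls are taken from the hunk above; the import paths, the AudioProcessor/TTSTokenizer construction via init_from_config, and the closing trainer.fit() are assumptions modeled on how other 🐸TTS recipes are laid out, and the config, output path, and sample lists are left to the caller.

# Hypothetical wrapper tying together the calls shown in the recipe diff.
# Everything outside the GlowTTS(...) and Trainer(...) calls is an assumption.
from TTS.trainer import Trainer, TrainingArgs  # import path assumed for this point in history
from TTS.tts.models.glow_tts import GlowTTS
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor


def train_glow_tts(config, output_path, train_samples, eval_samples):
    # Assumed setup, mirroring newer 🐸TTS recipes; the recipe at this commit
    # may build the audio processor and tokenizer differently.
    ap = AudioProcessor.init_from_config(config)
    tokenizer, config = TTSTokenizer.init_from_config(config)

    # From the diff: the model now takes the tokenizer explicitly.
    model = GlowTTS(config, ap, tokenizer, speaker_manager=None)

    # From the diff: the same Trainer call as before, re-wrapped one argument per line.
    trainer = Trainer(
        TrainingArgs(),
        config,
        output_path,
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
    )

    # AND... 3,2,1... 🚀
    trainer.fit()  # assumption: 🐸TTS recipes end by calling trainer.fit()

A recipe script would call train_glow_tts(...) after building its GlowTTSConfig and loading the dataset samples; the re-wrapped Trainer call itself changes formatting only, not training behavior.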