diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index cfd092f1..e28e9dec 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -4,6 +4,7 @@ from TTS.config import load_config, register_config from TTS.trainer import Trainer, TrainingArgs from TTS.tts.datasets import load_tts_samples from TTS.tts.models import setup_model +from TTS.tts.utils.speakers import SpeakerManager from TTS.utils.audio import AudioProcessor @@ -43,8 +44,16 @@ def main(): # setup audio processor ap = AudioProcessor(**config.audio) + # init speaker manager + if config.use_speaker_embedding: + speaker_manager = SpeakerManager(data_items=train_samples + eval_samples) + elif config.use_d_vector_file: + speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file) + else: + speaker_manager = None + # init the model from config - model = setup_model(config) + model = setup_model(config, speaker_manager) # init the trainer and 🚀 trainer = Trainer(