mirror of https://github.com/coqui-ai/TTS.git
Update train_tts for the new API
This commit is contained in:
parent
001da8afc8
commit
8e248913d6
|
@ -44,43 +44,8 @@ def main():
|
||||||
# load training samples
|
# load training samples
|
||||||
train_samples, eval_samples = load_tts_samples(config.datasets, eval_split=True, eval_split_max_size=config.eval_split_max_size, eval_split_size=config.eval_split_size)
|
train_samples, eval_samples = load_tts_samples(config.datasets, eval_split=True, eval_split_max_size=config.eval_split_max_size, eval_split_size=config.eval_split_size)
|
||||||
|
|
||||||
# setup audio processor
|
|
||||||
ap = AudioProcessor(**config.audio)
|
|
||||||
|
|
||||||
# init speaker manager
|
|
||||||
if check_config_and_model_args(config, "use_speaker_embedding", True):
|
|
||||||
speaker_manager = SpeakerManager(data_items=train_samples + eval_samples)
|
|
||||||
if hasattr(config, "model_args"):
|
|
||||||
config.model_args.num_speakers = speaker_manager.num_speakers
|
|
||||||
else:
|
|
||||||
config.num_speakers = speaker_manager.num_speakers
|
|
||||||
elif check_config_and_model_args(config, "use_d_vector_file", True):
|
|
||||||
if check_config_and_model_args(config, "use_speaker_encoder_as_loss", True):
|
|
||||||
speaker_manager = SpeakerManager(
|
|
||||||
d_vectors_file_path=config.model_args.d_vector_file,
|
|
||||||
encoder_model_path=config.model_args.speaker_encoder_model_path,
|
|
||||||
encoder_config_path=config.model_args.speaker_encoder_config_path,
|
|
||||||
use_cuda=torch.cuda.is_available(),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
speaker_manager = SpeakerManager(d_vectors_file_path=get_from_config_or_model_args(config, "d_vector_file"))
|
|
||||||
config.num_speakers = speaker_manager.num_speakers
|
|
||||||
if hasattr(config, "model_args"):
|
|
||||||
config.model_args.num_speakers = speaker_manager.num_speakers
|
|
||||||
else:
|
|
||||||
speaker_manager = None
|
|
||||||
|
|
||||||
if check_config_and_model_args(config, "use_language_embedding", True):
|
|
||||||
language_manager = LanguageManager(config=config)
|
|
||||||
if hasattr(config, "model_args"):
|
|
||||||
config.model_args.num_languages = language_manager.num_languages
|
|
||||||
else:
|
|
||||||
config.num_languages = language_manager.num_languages
|
|
||||||
else:
|
|
||||||
language_manager = None
|
|
||||||
|
|
||||||
# init the model from config
|
# init the model from config
|
||||||
model = setup_model(config, speaker_manager, language_manager)
|
model = setup_model(config, train_samples + eval_samples)
|
||||||
|
|
||||||
# init the trainer and 🚀
|
# init the trainer and 🚀
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
|
|
Loading…
Reference in New Issue