mirror of https://github.com/coqui-ai/TTS.git
add resnet speaker encoder train unit test
This commit is contained in:
parent
7448177b72
commit
cc192b6843
|
@ -9,7 +9,6 @@ from TTS.speaker_encoder.speaker_encoder_config import SpeakerEncoderConfig
|
|||
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||
|
||||
|
||||
config = SpeakerEncoderConfig(
|
||||
batch_size=4,
|
||||
num_speakers_in_batch=1,
|
||||
|
@ -45,3 +44,28 @@ command_train = (
|
|||
)
|
||||
run_cli(command_train)
|
||||
shutil.rmtree(continue_path)
|
||||
|
||||
# test resnet speaker encoder
|
||||
config.model_params['model_name'] = "resnet"
|
||||
config.save_json(config_path)
|
||||
|
||||
# train the model for one epoch
|
||||
command_train = (
|
||||
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_encoder.py --config_path {config_path} "
|
||||
f"--coqpit.output_path {output_path} "
|
||||
"--coqpit.datasets.0.name ljspeech "
|
||||
"--coqpit.datasets.0.meta_file_train metadata.csv "
|
||||
"--coqpit.datasets.0.meta_file_val metadata.csv "
|
||||
"--coqpit.datasets.0.path tests/data/ljspeech "
|
||||
)
|
||||
run_cli(command_train)
|
||||
|
||||
# Find latest folder
|
||||
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
|
||||
|
||||
# restore the model and continue training for one more epoch
|
||||
command_train = (
|
||||
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_encoder.py --continue_path {continue_path} "
|
||||
)
|
||||
run_cli(command_train)
|
||||
shutil.rmtree(continue_path)
|
||||
|
|
Loading…
Reference in New Issue