From cc192b6843e7020a6665f8c699ca44b1171540c8 Mon Sep 17 00:00:00 2001
From: Edresson
Date: Sat, 29 May 2021 22:43:41 -0300
Subject: [PATCH] add resnet speaker encoder train unit test

---
Review notes (this section sits after the three-dash separator, so it is
commentary and not part of the commit message):

  * Hunk 1 removes one blank line in front of the
    `config = SpeakerEncoderConfig(` assignment; no code changes.
  * Hunk 2 appends a second pass to the script: it sets
    config.model_params['model_name'] to "resnet", saves the config to
    config_path again, runs TTS/bin/train_encoder.py with --config_path,
    picks the most recently modified directory under output_path, reruns
    the trainer with --continue_path on it, and finally removes that
    directory with shutil.rmtree.
  * NOTE(review): the line breaks and the 4-space indentation inside the
    parenthesised literals were reconstructed from the hunk line counts
    (7 -> 6 and 3 -> 28); confirm against the upstream file before applying.

 tests/test_speaker_encoder_train.py | 26 +++++++++++++++++++++++++-
 1 file changed, 25 insertions(+), 1 deletion(-)

diff --git a/tests/test_speaker_encoder_train.py b/tests/test_speaker_encoder_train.py
index 831c48f2..e168a785 100644
--- a/tests/test_speaker_encoder_train.py
+++ b/tests/test_speaker_encoder_train.py
@@ -9,7 +9,6 @@ from TTS.speaker_encoder.speaker_encoder_config import SpeakerEncoderConfig
 config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
 output_path = os.path.join(get_tests_output_path(), "train_outputs")
 
-
 config = SpeakerEncoderConfig(
     batch_size=4,
     num_speakers_in_batch=1,
@@ -45,3 +44,28 @@ command_train = (
 )
 run_cli(command_train)
 shutil.rmtree(continue_path)
+
+# test resnet speaker encoder
+config.model_params['model_name'] = "resnet"
+config.save_json(config_path)
+
+# train the model for one epoch
+command_train = (
+    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_encoder.py --config_path {config_path} "
+    f"--coqpit.output_path {output_path} "
+    "--coqpit.datasets.0.name ljspeech "
+    "--coqpit.datasets.0.meta_file_train metadata.csv "
+    "--coqpit.datasets.0.meta_file_val metadata.csv "
+    "--coqpit.datasets.0.path tests/data/ljspeech "
+)
+run_cli(command_train)
+
+# Find latest folder
+continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
+
+# restore the model and continue training for one more epoch
+command_train = (
+    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_encoder.py --continue_path {continue_path} "
+)
+run_cli(command_train)
+shutil.rmtree(continue_path)