From 4eac1c4651f0e9adf3d0618cfac10ea4d4e8bd01 Mon Sep 17 00:00:00 2001 From: Edresson Date: Sun, 11 Jul 2021 12:00:39 -0300 Subject: [PATCH] bug fix on train_encoder and unit tests --- TTS/bin/train_encoder.py | 2 +- tests/test_speaker_encoder_train.py | 49 ++++++++++++++++++----------- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 38902a18..2bb5bfc7 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -164,7 +164,7 @@ def main(args): # pylint: disable=redefined-outer-name elif c.loss == "angleproto": criterion = AngleProtoLoss() elif c.loss == "softmaxproto": - criterion = SoftmaxAngleProtoLoss(c.model["proj_dim"], num_speakers) + criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_speakers) else: raise Exception("The %s not is a loss supported" % c.loss) diff --git a/tests/test_speaker_encoder_train.py b/tests/test_speaker_encoder_train.py index 21b12074..4419a00f 100644 --- a/tests/test_speaker_encoder_train.py +++ b/tests/test_speaker_encoder_train.py @@ -6,7 +6,18 @@ from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseAudioConfig from TTS.speaker_encoder.speaker_encoder_config import SpeakerEncoderConfig -config_path = os.path.join(get_tests_output_path(), "test_model_config.json") +def run_test_train(): + command = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_encoder.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.datasets.0.name ljspeech " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + ) + run_cli(command) + +config_path = os.path.join(get_tests_output_path(), "test_speaker_encoder_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") config = SpeakerEncoderConfig( @@ -24,16 +35,9 @@ config.audio.do_trim_silence = True config.audio.trim_db = 60 config.save_json(config_path) +print(config) # train the model for one epoch -command_train = ( - f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_encoder.py --config_path {config_path} " - f"--coqpit.output_path {output_path} " - "--coqpit.datasets.0.name ljspeech " - "--coqpit.datasets.0.meta_file_train metadata.csv " - "--coqpit.datasets.0.meta_file_val metadata.csv " - "--coqpit.datasets.0.path tests/data/ljspeech " -) -run_cli(command_train) +run_test_train() # Find latest folder continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) @@ -50,15 +54,7 @@ config.model_params["model_name"] = "resnet" config.save_json(config_path) # train the model for one epoch -command_train = ( - f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_encoder.py --config_path {config_path} " - f"--coqpit.output_path {output_path} " - "--coqpit.datasets.0.name ljspeech " - "--coqpit.datasets.0.meta_file_train metadata.csv " - "--coqpit.datasets.0.meta_file_val metadata.csv " - "--coqpit.datasets.0.path tests/data/ljspeech " -) -run_cli(command_train) +run_test_train() # Find latest folder continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) @@ -69,3 +65,18 @@ command_train = ( ) run_cli(command_train) shutil.rmtree(continue_path) + +# test model with ge2e loss function +config.loss = "ge2e" +config.save_json(config_path) +run_test_train() + +# test model with angleproto loss function +config.loss = "angleproto" +config.save_json(config_path) +run_test_train() + +# test model with softmaxproto loss function +config.loss = "softmaxproto" +config.save_json(config_path) +run_test_train()