From 19d9f580093243d0bbf23fa6ff2c41034fb251da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 27 Apr 2021 13:27:24 +0200 Subject: [PATCH] create dummy model on the fly --- tests/test_speaker_manager.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/test_speaker_manager.py b/tests/test_speaker_manager.py index b8697ca8..8fa68834 100644 --- a/tests/test_speaker_manager.py +++ b/tests/test_speaker_manager.py @@ -6,11 +6,13 @@ import torch from tests import get_tests_input_path from TTS.tts.utils.speakers import SpeakerManager +from TTS.speaker_encoder.model import SpeakerEncoder +from TTS.speaker_encoder.utils.generic_utils import save_checkpoint from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_config encoder_config_path = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json") -encoder_model_path = os.path.join(get_tests_input_path(), "dummy_speaker_encoder.pth.tar") +encoder_model_path = os.path.join(get_tests_input_path(), "checkpoint_0.pth.tar") sample_wav_path = os.path.join(get_tests_input_path(), "../data/ljspeech/wavs/LJ001-0001.wav") sample_wav_path2 = os.path.join(get_tests_input_path(), "../data/ljspeech/wavs/LJ001-0002.wav") x_vectors_file_path = os.path.join(get_tests_input_path(), "../data/dummy_speakers.json") @@ -25,6 +27,10 @@ class SpeakerManagerTest(unittest.TestCase): config = load_config(encoder_config_path) config["audio"]["resample"] = True + # create a dummy speaker encoder + model = SpeakerEncoder(**config.model) + save_checkpoint(model, None, None, get_tests_input_path(), 0, 0) + # load audio processor and speaker encoder ap = AudioProcessor(**config.audio) manager = SpeakerManager(encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path) @@ -49,6 +55,9 @@ class SpeakerManagerTest(unittest.TestCase): assert x_vector3.shape[0] == 256 assert (x_vector - x_vector3).sum() != 0.0 + # remove dummy model + os.remove(encoder_model_path) + @staticmethod def test_speakers_file_processing(): manager = SpeakerManager(x_vectors_file_path=x_vectors_file_path)