mirror of https://github.com/coqui-ai/TTS.git
create dummy model on the fly
This commit is contained in:
parent
4719414f2e
commit
19d9f58009
|
@ -6,11 +6,13 @@ import torch
|
||||||
|
|
||||||
from tests import get_tests_input_path
|
from tests import get_tests_input_path
|
||||||
from TTS.tts.utils.speakers import SpeakerManager
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
|
from TTS.speaker_encoder.model import SpeakerEncoder
|
||||||
|
from TTS.speaker_encoder.utils.generic_utils import save_checkpoint
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.io import load_config
|
from TTS.utils.io import load_config
|
||||||
|
|
||||||
encoder_config_path = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
|
encoder_config_path = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
|
||||||
encoder_model_path = os.path.join(get_tests_input_path(), "dummy_speaker_encoder.pth.tar")
|
encoder_model_path = os.path.join(get_tests_input_path(), "checkpoint_0.pth.tar")
|
||||||
sample_wav_path = os.path.join(get_tests_input_path(), "../data/ljspeech/wavs/LJ001-0001.wav")
|
sample_wav_path = os.path.join(get_tests_input_path(), "../data/ljspeech/wavs/LJ001-0001.wav")
|
||||||
sample_wav_path2 = os.path.join(get_tests_input_path(), "../data/ljspeech/wavs/LJ001-0002.wav")
|
sample_wav_path2 = os.path.join(get_tests_input_path(), "../data/ljspeech/wavs/LJ001-0002.wav")
|
||||||
x_vectors_file_path = os.path.join(get_tests_input_path(), "../data/dummy_speakers.json")
|
x_vectors_file_path = os.path.join(get_tests_input_path(), "../data/dummy_speakers.json")
|
||||||
|
@ -25,6 +27,10 @@ class SpeakerManagerTest(unittest.TestCase):
|
||||||
config = load_config(encoder_config_path)
|
config = load_config(encoder_config_path)
|
||||||
config["audio"]["resample"] = True
|
config["audio"]["resample"] = True
|
||||||
|
|
||||||
|
# create a dummy speaker encoder
|
||||||
|
model = SpeakerEncoder(**config.model)
|
||||||
|
save_checkpoint(model, None, None, get_tests_input_path(), 0, 0)
|
||||||
|
|
||||||
# load audio processor and speaker encoder
|
# load audio processor and speaker encoder
|
||||||
ap = AudioProcessor(**config.audio)
|
ap = AudioProcessor(**config.audio)
|
||||||
manager = SpeakerManager(encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path)
|
manager = SpeakerManager(encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path)
|
||||||
|
@ -49,6 +55,9 @@ class SpeakerManagerTest(unittest.TestCase):
|
||||||
assert x_vector3.shape[0] == 256
|
assert x_vector3.shape[0] == 256
|
||||||
assert (x_vector - x_vector3).sum() != 0.0
|
assert (x_vector - x_vector3).sum() != 0.0
|
||||||
|
|
||||||
|
# remove dummy model
|
||||||
|
os.remove(encoder_model_path)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def test_speakers_file_processing():
|
def test_speakers_file_processing():
|
||||||
manager = SpeakerManager(x_vectors_file_path=x_vectors_file_path)
|
manager = SpeakerManager(x_vectors_file_path=x_vectors_file_path)
|
||||||
|
|
Loading…
Reference in New Issue