From 5f8ad4c64b26960dad6b1399deae5f9a0a4aade2 Mon Sep 17 00:00:00 2001
From: Enno Hermann
Date: Fri, 29 Nov 2024 17:23:30 +0100
Subject: [PATCH] test(openvoice): add sanity check

---
 tests/vc_tests/test_openvoice.py | 42 ++++++++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)
 create mode 100644 tests/vc_tests/test_openvoice.py

diff --git a/tests/vc_tests/test_openvoice.py b/tests/vc_tests/test_openvoice.py
new file mode 100644
index 00000000..c9f7ae39
--- /dev/null
+++ b/tests/vc_tests/test_openvoice.py
@@ -0,0 +1,42 @@
+import os
+import unittest
+
+import torch
+
+from tests import get_tests_input_path
+from TTS.vc.models.openvoice import OpenVoice, OpenVoiceConfig
+
+torch.manual_seed(1)
+use_cuda = torch.cuda.is_available()
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+c = OpenVoiceConfig()
+
+WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
+
+
+class TestOpenVoice(unittest.TestCase):
+
+    @staticmethod
+    def _create_inputs_inference():
+        source_wav = torch.rand(16100)
+        target_wav = torch.rand(16000)
+        return source_wav, target_wav
+
+    def test_load_audio(self):
+        config = OpenVoiceConfig()
+        model = OpenVoice(config).to(device)
+        wav = model.load_audio(WAV_FILE)
+        wav2 = model.load_audio(wav)
+        assert all(torch.isclose(wav, wav2))
+
+    def test_voice_conversion(self):
+        config = OpenVoiceConfig()
+        model = OpenVoice(config).to(device)
+        model.eval()
+
+        source_wav, target_wav = self._create_inputs_inference()
+        output_wav = model.voice_conversion(source_wav, target_wav)
+        assert (
+            output_wav.shape[0] == source_wav.shape[0] - source_wav.shape[0] % config.audio.hop_length
+        ), f"{output_wav.shape} != {source_wav.shape}"