Add multiples references on xtts inference tests

This commit is contained in:
Edresson Casanova 2023-11-06 15:09:33 -03:00 committed by Eren G??lge
parent 1b6f8d0e46
commit f444f296f2
1 changed files with 7 additions and 4 deletions

View File

@ -101,7 +101,9 @@ def test_xtts_streaming():
from TTS.tts.configs.xtts_config import XttsConfig from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts from TTS.tts.models.xtts import Xtts
speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav") speaker_wav = [os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")]
speaker_wav_2 = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0002.wav")
speaker_wav.append(speaker_wav_2)
model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1") model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1")
config = XttsConfig() config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json")) config.load_json(os.path.join(model_path, "config.json"))
@ -131,20 +133,21 @@ def test_xtts_v2():
"""XTTS is too big to run on github actions. We need to test it locally""" """XTTS is too big to run on github actions. We need to test it locally"""
output_path = os.path.join(get_tests_output_path(), "output.wav") output_path = os.path.join(get_tests_output_path(), "output.wav")
speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav") speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
speaker_wav_2 = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0002.wav")
use_gpu = torch.cuda.is_available() use_gpu = torch.cuda.is_available()
if use_gpu: if use_gpu:
run_cli( run_cli(
"yes | " "yes | "
f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 " f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 "
f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True ' f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
f'--speaker_wav "{speaker_wav}" --language_idx "en"' f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" "--language_idx "en"'
) )
else: else:
run_cli( run_cli(
"yes | " "yes | "
f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 " f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 "
f'--text "This is an example." --out_path "{output_path}" --progress_bar False ' f'--text "This is an example." --out_path "{output_path}" --progress_bar False '
f'--speaker_wav "{speaker_wav}" --language_idx "en"' f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" --language_idx "en"'
) )
@ -153,7 +156,7 @@ def test_xtts_v2_streaming():
from TTS.tts.configs.xtts_config import XttsConfig from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts from TTS.tts.models.xtts import Xtts
speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav") speaker_wav = [os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")]
model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v2") model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v2")
config = XttsConfig() config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json")) config.load_json(os.path.join(model_path, "config.json"))