Add streaming test

This commit is contained in:
WeberJulian 2023-10-06 14:50:30 +02:00
parent a097541ed4
commit a357f81ded
1 changed files with 28 additions and 0 deletions

View File

@ -93,6 +93,34 @@ def test_xtts():
f'--speaker_wav "{speaker_wav}" --language_idx "en"'
)
def test_xtts_streaming():
"""Testing the new inference_stream method"""
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1")
config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json"))
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir=model_path)
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
print("Computing speaker latents...")
gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
print("Inference...")
chunks = model.inference_stream(
"It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
"en",
gpt_cond_latent,
speaker_embedding
)
wav_chuncks = []
for i, chunk in enumerate(chunks):
if i == 0:
assert chunk.shape[-1] > 5000
wav_chuncks.append(chunk)
assert len(wav_chuncks) > 1
def test_tortoise():
output_path = os.path.join(get_tests_output_path(), "output.wav")