mirror of https://github.com/coqui-ai/TTS.git
Add speed control for inference (#3214)
* Add speed control for inference
* Fix XTTS tests
* Add speed control tests
commit 04901fb2e4
parent d96f3885d5
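The change threads a new `speed` argument through both `inference` and `inference_stream`. Internally, `speed` is converted to `length_scale = 1.0 / max(speed, 0.05)`, and the GPT latents are linearly interpolated along the time axis before the HiFi-GAN decoder renders audio, so `speed > 1.0` yields fewer latent frames and shorter, faster speech. A minimal sketch of the mechanism — not taken from the commit, and with tensor shapes as illustrative assumptions:

import torch
import torch.nn.functional as F

speed = 2.0                              # >1.0 speeds speech up, <1.0 slows it down
length_scale = 1.0 / max(speed, 0.05)    # the clamp guards against division by ~0

gpt_latents = torch.randn(1, 300, 1024)  # (batch, time, channels); made-up sizes
if length_scale != 1.0:
    # F.interpolate expects (batch, channels, time), hence the transposes
    gpt_latents = F.interpolate(
        gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
    ).transpose(1, 2)

print(gpt_latents.shape)  # torch.Size([1, 150, 1024]): half the frames at speed=2.0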
@@ -530,8 +530,10 @@ class Xtts(BaseTTS):
         top_p=0.85,
         do_sample=True,
         num_beams=1,
+        speed=1.0,
         **hf_generate_kwargs,
     ):
+        length_scale = 1.0 / max(speed, 0.05)
         text = text.strip().lower()
         text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
 
@@ -584,6 +586,13 @@ class Xtts(BaseTTS):
                     gpt_latents = gpt_latents[:, :k]
                     break
 
+            if length_scale != 1.0:
+                gpt_latents = F.interpolate(
+                    gpt_latents.transpose(1, 2),
+                    scale_factor=length_scale,
+                    mode="linear"
+                ).transpose(1, 2)
+
             wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
 
         return {
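With this hunk in place, the non-streaming call could look like the sketch below (paths are hypothetical placeholders; the setup follows the usual XTTS loading pattern):

import torch
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

config = XttsConfig()
config.load_json("/path/to/config.json")                         # hypothetical path
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/model/")  # hypothetical path
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path="speaker.wav")
out = model.inference(
    "This is a test.",
    "en",
    gpt_cond_latent,
    speaker_embedding,
    speed=1.3,  # ~30% faster than the 1.0 default
)
wav = out["wav"]  # decoded waveform, shortened by the latent interpolation above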
@@ -634,8 +643,10 @@ class Xtts(BaseTTS):
         top_k=50,
         top_p=0.85,
         do_sample=True,
+        speed=1.0,
         **hf_generate_kwargs,
     ):
+        length_scale = 1.0 / max(speed, 0.05)
         text = text.strip().lower()
         text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
 
@@ -674,6 +685,12 @@ class Xtts(BaseTTS):
 
                 if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
                     gpt_latents = torch.cat(all_latents, dim=0)[None, :]
+                    if length_scale != 1.0:
+                        gpt_latents = F.interpolate(
+                            gpt_latents.transpose(1, 2),
+                            scale_factor=length_scale,
+                            mode="linear"
+                        ).transpose(1, 2)
                     wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
                     wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
                         wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
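In the streaming path the interpolation is applied to the accumulated latents before each chunk is vocoded, so a consumer concatenates chunks exactly as before. A sketch reusing the model and latents from the previous example:

chunks = model.inference_stream(
    "This is a test.",
    "en",
    gpt_cond_latent,
    speaker_embedding,
    speed=0.8,  # slower-paced speech
)
wav = torch.cat(list(chunks), dim=0)  # chunks arrive as 1-D tensors on the model device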
@@ -111,7 +111,7 @@ def test_xtts_streaming():
     model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
 
     print("Computing speaker latents...")
-    gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
+    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
 
     print("Inference...")
     chunks = model.inference_stream(
@@ -139,7 +139,7 @@ def test_xtts_v2():
             "yes | "
             f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 "
             f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
-            f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" "--language_idx "en"'
+            f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" --language_idx "en"'
         )
     else:
         run_cli(
@@ -164,7 +164,7 @@ def test_xtts_v2_streaming():
     model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
 
     print("Computing speaker latents...")
-    gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
+    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
 
     print("Inference...")
     chunks = model.inference_stream(
@@ -179,6 +179,34 @@ def test_xtts_v2_streaming():
             assert chunk.shape[-1] > 5000
         wav_chuncks.append(chunk)
     assert len(wav_chuncks) > 1
+    normal_len = sum([len(chunk) for chunk in wav_chuncks])
+
+    chunks = model.inference_stream(
+        "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
+        "en",
+        gpt_cond_latent,
+        speaker_embedding,
+        speed=1.5
+    )
+    wav_chuncks = []
+    for i, chunk in enumerate(chunks):
+        wav_chuncks.append(chunk)
+    fast_len = sum([len(chunk) for chunk in wav_chuncks])
+
+    chunks = model.inference_stream(
+        "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
+        "en",
+        gpt_cond_latent,
+        speaker_embedding,
+        speed=0.66
+    )
+    wav_chuncks = []
+    for i, chunk in enumerate(chunks):
+        wav_chuncks.append(chunk)
+    slow_len = sum([len(chunk) for chunk in wav_chuncks])
+
+    assert slow_len > normal_len
+    assert normal_len > fast_len
 
 
 def test_tortoise():
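The new assertions rely on output duration scaling roughly with 1/speed: at a fixed sample rate, speed=1.5 should produce about two thirds as many samples as the speed=1.0 run, and speed=0.66 about 1.5x as many, hence slow_len > normal_len > fast_len. A back-of-the-envelope check under that assumption (all numbers made up):

normal_len = 120_000               # e.g. ~5 s at 24 kHz
fast_len = int(normal_len / 1.5)   # speed=1.5 -> 80_000 samples
slow_len = int(normal_len / 0.66)  # speed=0.66 -> ~181_818 samples
assert slow_len > normal_len > fast_len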