mirror of https://github.com/coqui-ai/TTS.git

Add speed control for inference

commit b3b22a04ca
parent d96f3885d5
@@ -530,8 +530,10 @@ class Xtts(BaseTTS):
         top_p=0.85,
         do_sample=True,
         num_beams=1,
+        speed=1.0,
         **hf_generate_kwargs,
     ):
+        length_scale = 1.0 / max(speed, 0.05)
         text = text.strip().lower()
         text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)

@@ -584,6 +586,13 @@ class Xtts(BaseTTS):
                     gpt_latents = gpt_latents[:, :k]
                     break

+            if length_scale != 1.0:
+                gpt_latents = F.interpolate(
+                    gpt_latents.transpose(1, 2),
+                    scale_factor=length_scale,
+                    mode="linear"
+                ).transpose(1, 2)
+
             wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)

         return {
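The stretch is applied to the GPT latents before the HiFi-GAN vocoder rather than to the waveform: F.interpolate with mode="linear" resamples the last dimension of a (batch, channels, time) tensor, hence the pair of transposes. A toy demonstration, assuming the latents are laid out as (batch, time, channels) with an invented width of 1024:

import torch
import torch.nn.functional as F

# Toy demonstration of the time-stretch above; shapes are assumptions.
gpt_latents = torch.randn(1, 100, 1024)   # (batch, time, channels), 100 latent frames
length_scale = 1.0 / max(1.25, 0.05)      # speed=1.25 -> 0.8

stretched = F.interpolate(
    gpt_latents.transpose(1, 2),          # (1, 1024, 100): "linear" resamples the last dim
    scale_factor=length_scale,
    mode="linear",
).transpose(1, 2)                         # back to (batch, time, channels)

print(stretched.shape)                    # torch.Size([1, 80, 1024]): fewer frames = faster speech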
@@ -634,8 +643,10 @@ class Xtts(BaseTTS):
         top_k=50,
         top_p=0.85,
         do_sample=True,
+        speed=1.0,
         **hf_generate_kwargs,
     ):
+        length_scale = 1.0 / max(speed, 0.05)
         text = text.strip().lower()
         text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)

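The same speed parameter and length_scale computation are duplicated in this second signature, which (judging from the stream_chunk_size and handle_chunks context below) appears to be the streaming entry point. Assuming it is XTTS's inference_stream generator, and with model, gpt_cond_latent, and the yielded chunk format as assumptions not shown in this diff, usage might look like:

# Hypothetical call sketch; names other than speed are assumptions.
chunks = []
for wav_chunk in model.inference_stream(
    "The quick brown fox.",
    "en",
    gpt_cond_latent,
    speaker_embedding,
    speed=1.2,   # ~20% faster; values below 1.0 slow speech down
):
    chunks.append(wav_chunk)
wav = torch.cat(chunks, dim=0)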
@@ -674,6 +685,12 @@ class Xtts(BaseTTS):

             if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
                 gpt_latents = torch.cat(all_latents, dim=0)[None, :]
+                if length_scale != 1.0:
+                    gpt_latents = F.interpolate(
+                        gpt_latents.transpose(1, 2),
+                        scale_factor=length_scale,
+                        mode="linear"
+                    ).transpose(1, 2)
                 wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
                 wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
                     wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
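Note that in this streaming branch the entire latent history (all_latents) is re-concatenated and re-interpolated at every flush, so the stretch factor stays consistent across chunk boundaries before handle_chunks crossfades the overlap. A toy re-creation of just that accumulation logic, with invented frame counts and channel width:

import torch
import torch.nn.functional as F

# Toy re-creation of the streaming branch above; 20 frames per flush
# and the 1024-channel width are invented for illustration.
length_scale = 1.0 / max(0.8, 0.05)                    # speed=0.8 -> stretch by 1.25x
all_latents = []
for step in range(3):
    all_latents.append(torch.randn(20, 1024))          # new latent frames for this flush
    gpt_latents = torch.cat(all_latents, dim=0)[None, :]
    if length_scale != 1.0:
        gpt_latents = F.interpolate(
            gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
        ).transpose(1, 2)
    print(step, gpt_latents.shape)                     # time axis grows 25 -> 50 -> 75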