From b3b22a04ca28808fd01f69b01bfc950994281f07 Mon Sep 17 00:00:00 2001 From: WeberJulian Date: Tue, 14 Nov 2023 11:53:52 +0100 Subject: [PATCH] Add speed control for inference --- TTS/tts/models/xtts.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index b277c3ac..91985912 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -530,8 +530,10 @@ class Xtts(BaseTTS): top_p=0.85, do_sample=True, num_beams=1, + speed=1.0, **hf_generate_kwargs, ): + length_scale = 1.0 / max(speed, 0.05) text = text.strip().lower() text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device) @@ -584,6 +586,13 @@ class Xtts(BaseTTS): gpt_latents = gpt_latents[:, :k] break + if length_scale != 1.0: + gpt_latents = F.interpolate( + gpt_latents.transpose(1, 2), + scale_factor=length_scale, + mode="linear" + ).transpose(1, 2) + wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding) return { @@ -634,8 +643,10 @@ class Xtts(BaseTTS): top_k=50, top_p=0.85, do_sample=True, + speed=1.0, **hf_generate_kwargs, ): + length_scale = 1.0 / max(speed, 0.05) text = text.strip().lower() text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device) @@ -674,6 +685,12 @@ class Xtts(BaseTTS): if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size): gpt_latents = torch.cat(all_latents, dim=0)[None, :] + if length_scale != 1.0: + gpt_latents = F.interpolate( + gpt_latents.transpose(1, 2), + scale_factor=length_scale, + mode="linear" + ).transpose(1, 2) wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device)) wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks( wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len