From a45cf83b34c558578cee2206ef81af628ab4ca72 Mon Sep 17 00:00:00 2001
From: WeberJulian
Date: Thu, 28 Sep 2023 17:51:49 +0200
Subject: [PATCH] Add inference with precomputed latents

---
 TTS/tts/models/xtts.py | 80 ++++++++++++++++++++++++++++--------------
 1 file changed, 53 insertions(+), 27 deletions(-)

diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index 049f1281..aa66924a 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -478,10 +478,10 @@ class Xtts(BaseTTS):
             "decoder_sampler": config.decoder_sampler,
         }
         settings.update(kwargs)  # allow overriding of preset settings with kwargs
-        return self.inference(text, ref_audio_path, language, **settings)
+        return self.full_inference(text, ref_audio_path, language, **settings)
 
     @torch.no_grad()
-    def inference(
+    def full_inference(
         self,
         text,
         ref_audio_path,
@@ -557,6 +557,56 @@ class Xtts(BaseTTS):
             Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
             Sample rate is 24kHz.
         """
+        (
+            gpt_cond_latent,
+            diffusion_conditioning,
+            speaker_embedding
+        ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
+        return self.inference(
+            text,
+            language,
+            gpt_cond_latent,
+            speaker_embedding,
+            diffusion_conditioning,
+            temperature=temperature,
+            length_penalty=length_penalty,
+            repetition_penalty=repetition_penalty,
+            top_k=top_k,
+            top_p=top_p,
+            do_sample=do_sample,
+            decoder_iterations=decoder_iterations,
+            cond_free=cond_free,
+            cond_free_k=cond_free_k,
+            diffusion_temperature=diffusion_temperature,
+            decoder_sampler=decoder_sampler,
+            use_hifigan=use_hifigan,
+            **hf_generate_kwargs,
+        )
+
+    @torch.no_grad()
+    def inference(
+        self,
+        text,
+        language,
+        gpt_cond_latent,
+        speaker_embedding,
+        diffusion_conditioning,
+        # GPT inference
+        temperature=0.65,
+        length_penalty=1,
+        repetition_penalty=2.0,
+        top_k=50,
+        top_p=0.85,
+        do_sample=True,
+        # Decoder inference
+        decoder_iterations=100,
+        cond_free=True,
+        cond_free_k=2,
+        diffusion_temperature=1.0,
+        decoder_sampler="ddim",
+        use_hifigan=True,
+        **hf_generate_kwargs,
+    ):
         text = f"[{language}]{text.strip().lower()}"
         text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
 
@@ -564,12 +614,6 @@ class Xtts(BaseTTS):
             text_tokens.shape[-1] < self.args.gpt_max_text_tokens
         ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
 
-        (
-            gpt_cond_latent,
-            diffusion_conditioning,
-            speaker_embedding
-        ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
-
         if not use_hifigan:
             diffuser = load_discrete_vocoder_diffuser(
                 desired_diffusion_steps=decoder_iterations,
@@ -636,18 +680,6 @@ class Xtts(BaseTTS):
 
         return {"wav": wav.cpu().numpy().squeeze()}
 
-    def inference_speaker_cond(self, ref_audio_path, gpt_cond_len=3):
-        (
-            gpt_cond_latent,
-            diffusion_conditioning,
-            speaker_embedding
-        ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=3)
-        return {
-            "gpt_cond_latent": gpt_cond_latent,
-            "speaker_embedding": speaker_embedding,
-            "diffusion_conditioning": diffusion_conditioning,
-        }
-
     def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
         """Handle chunk formatting in streaming mode"""
         wav_chunk = wav_gen[:-overlap_len]
@@ -668,7 +700,6 @@ class Xtts(BaseTTS):
         language,
         gpt_cond_latent,
         speaker_embedding,
-        diffusion_conditioning,
         # Streaming
         stream_chunk_size=20,
         overlap_wav_len=1024,
@@ -678,14 +709,8 @@ class Xtts(BaseTTS):
         repetition_penalty=2.0,
         top_k=50,
         top_p=0.85,
-        gpt_cond_len=4,
         do_sample=True,
         # Decoder inference
-        decoder_iterations=100,
-        cond_free=True,
-        cond_free_k=2,
-        diffusion_temperature=1.0,
-        decoder_sampler="ddim",
         **hf_generate_kwargs,
     ):
         text = f"[{language}]{text.strip().lower()}"
@@ -707,6 +732,7 @@ class Xtts(BaseTTS):
             repetition_penalty=float(repetition_penalty),
             output_attentions=False,
             output_hidden_states=True,
+            **hf_generate_kwargs,
         )
 
         last_tokens = []
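
Usage note (not part of the patch): with this change the conditioning latents are passed into inference() explicitly, so they can be computed once per speaker and reused for any number of utterances; full_inference() keeps the old one-call behaviour by recomputing them from ref_audio_path on every call, and inference_stream() now takes the same precomputed latents, minus diffusion_conditioning, which the streaming path does not use. A minimal sketch of the new flow, assuming the usual XTTS checkpoint-loading steps; every path below is a placeholder:

    from TTS.tts.configs.xtts_config import XttsConfig
    from TTS.tts.models.xtts import Xtts

    # Load the model (placeholder paths).
    config = XttsConfig()
    config.load_json("/path/to/config.json")
    model = Xtts.init_from_config(config)
    model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/")
    model.cuda()

    # Compute the speaker conditioning once; the return order matches the
    # unpacking in full_inference() above.
    (
        gpt_cond_latent,
        diffusion_conditioning,
        speaker_embedding,
    ) = model.get_conditioning_latents(audio_path="/path/to/reference.wav", gpt_cond_len=3)

    # Reuse the precomputed latents across calls; no reference audio needed here.
    out = model.inference(
        "It took me quite a long time to develop a voice.",
        "en",
        gpt_cond_latent,
        speaker_embedding,
        diffusion_conditioning,
    )
    wav = out["wav"]  # numpy array at 24 kHz, per the docstring above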