Add inference with precomputed latents

WeberJulian 2023-09-28 17:51:49 +02:00
parent a0f657c764
commit a45cf83b34
1 changed file with 53 additions and 27 deletions


@@ -478,10 +478,10 @@ class Xtts(BaseTTS):
             "decoder_sampler": config.decoder_sampler,
         }
         settings.update(kwargs)  # allow overriding of preset settings with kwargs
-        return self.inference(text, ref_audio_path, language, **settings)
+        return self.full_inference(text, ref_audio_path, language, **settings)

     @torch.no_grad()
-    def inference(
+    def full_inference(
         self,
         text,
         ref_audio_path,
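After this rename, the one-shot entry point that takes a reference wav is full_inference; the inference name is reused below for the latent-based path. A minimal before/after sketch, assuming model is a loaded Xtts instance (the variable name and loading step are not part of this diff):

# hypothetical caller; `model` is assumed to be a loaded Xtts instance
# pre-commit, the one-shot call was:
out = model.inference("Hello world.", "reference.wav", "en")
# post-commit, the same call is spelled:
out = model.full_inference("Hello world.", "reference.wav", "en")
# out["wav"] is the generated waveform at 24 kHz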
@@ -557,6 +557,56 @@ class Xtts(BaseTTS):
             Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
             Sample rate is 24kHz.
         """
+        (
+            gpt_cond_latent,
+            diffusion_conditioning,
+            speaker_embedding
+        ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
+        return self.inference(
+            text,
+            language,
+            gpt_cond_latent,
+            speaker_embedding,
+            diffusion_conditioning,
+            temperature=temperature,
+            length_penalty=length_penalty,
+            repetition_penalty=repetition_penalty,
+            top_k=top_k,
+            top_p=top_p,
+            do_sample=do_sample,
+            decoder_iterations=decoder_iterations,
+            cond_free=cond_free,
+            cond_free_k=cond_free_k,
+            diffusion_temperature=diffusion_temperature,
+            decoder_sampler=decoder_sampler,
+            use_hifigan=use_hifigan,
+            **hf_generate_kwargs,
+        )
+
+    @torch.no_grad()
+    def inference(
+        self,
+        text,
+        language,
+        gpt_cond_latent,
+        speaker_embedding,
+        diffusion_conditioning,
+        # GPT inference
+        temperature=0.65,
+        length_penalty=1,
+        repetition_penalty=2.0,
+        top_k=50,
+        top_p=0.85,
+        do_sample=True,
+        # Decoder inference
+        decoder_iterations=100,
+        cond_free=True,
+        cond_free_k=2,
+        diffusion_temperature=1.0,
+        decoder_sampler="ddim",
+        use_hifigan=True,
+        **hf_generate_kwargs,
+    ):
         text = f"[{language}]{text.strip().lower()}"
         text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
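This split is the point of the commit: get_conditioning_latents runs once per speaker, and inference can then be called many times without re-encoding the reference audio. A hedged usage sketch, again assuming a loaded Xtts instance named model:

(
    gpt_cond_latent,
    diffusion_conditioning,
    speaker_embedding,
) = model.get_conditioning_latents(audio_path="reference.wav", gpt_cond_len=3)

for text in ["First sentence.", "Second sentence."]:
    out = model.inference(
        text,
        "en",
        gpt_cond_latent,
        speaker_embedding,
        diffusion_conditioning,
    )
    # out["wav"] is a 24 kHz waveform, per the docstring above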
@@ -564,12 +614,6 @@ class Xtts(BaseTTS):
             text_tokens.shape[-1] < self.args.gpt_max_text_tokens
         ), " ❗ XTTS can only generate text with a maximum of 400 tokens."

-        (
-            gpt_cond_latent,
-            diffusion_conditioning,
-            speaker_embedding
-        ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
-
         if not use_hifigan:
             diffuser = load_discrete_vocoder_diffuser(
                 desired_diffusion_steps=decoder_iterations,
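Because the latents are ordinary tensors, removing their computation from inference also makes it possible to cache them to disk and skip the reference wav entirely on later runs. A sketch of that pattern (the file name is illustrative, not part of this diff):

import torch

latents = model.get_conditioning_latents(audio_path="reference.wav", gpt_cond_len=3)
torch.save(latents, "speaker_latents.pt")

# later, possibly in another process, with the same tuple order as above:
gpt_cond_latent, diffusion_conditioning, speaker_embedding = torch.load("speaker_latents.pt")
out = model.inference("Cached-speaker synthesis.", "en", gpt_cond_latent, speaker_embedding, diffusion_conditioning)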
@@ -636,18 +680,6 @@ class Xtts(BaseTTS):
             return {"wav": wav.cpu().numpy().squeeze()}

-    def inference_speaker_cond(self, ref_audio_path, gpt_cond_len=3):
-        (
-            gpt_cond_latent,
-            diffusion_conditioning,
-            speaker_embedding
-        ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=3)
-        return {
-            "gpt_cond_latent": gpt_cond_latent,
-            "speaker_embedding": speaker_embedding,
-            "diffusion_conditioning": diffusion_conditioning,
-        }
-
     def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
         """Handle chunk formatting in streaming mode"""
         wav_chunk = wav_gen[:-overlap_len]
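Note that the removed inference_speaker_cond helper hard-coded gpt_cond_len=3 and ignored its own argument. Callers that relied on it can reproduce it (with that quirk fixed) on top of the public get_conditioning_latents, as in this sketch:

def speaker_cond(model, ref_audio_path, gpt_cond_len=3):
    # same dict shape as the removed helper, but gpt_cond_len is actually honored
    gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(
        audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len
    )
    return {
        "gpt_cond_latent": gpt_cond_latent,
        "speaker_embedding": speaker_embedding,
        "diffusion_conditioning": diffusion_conditioning,
    }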
@@ -668,7 +700,6 @@ class Xtts(BaseTTS):
         language,
         gpt_cond_latent,
         speaker_embedding,
-        diffusion_conditioning,
         # Streaming
         stream_chunk_size=20,
         overlap_wav_len=1024,
@@ -678,14 +709,8 @@ class Xtts(BaseTTS):
         repetition_penalty=2.0,
         top_k=50,
         top_p=0.85,
-        gpt_cond_len=4,
         do_sample=True,
         # Decoder inference
-        decoder_iterations=100,
-        cond_free=True,
-        cond_free_k=2,
-        diffusion_temperature=1.0,
-        decoder_sampler="ddim",
         **hf_generate_kwargs,
     ):
         text = f"[{language}]{text.strip().lower()}"
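The streaming signature drops diffusion_conditioning, gpt_cond_len, and all diffusion-decoder options, so it now takes only precomputed GPT latents and a speaker embedding (the diffusion path is apparently not used when streaming). A hedged sketch of a streaming call; the method name inference_stream is assumed from the surrounding file, since these hunks omit it:

chunks = model.inference_stream(
    "A longer piece of text to stream.",
    "en",
    gpt_cond_latent,
    speaker_embedding,
    stream_chunk_size=20,
)
for i, wav_chunk in enumerate(chunks):
    # each yielded chunk is a waveform segment, stitched by handle_chunks above
    print(f"chunk {i}: {wav_chunk.shape[-1]} samples")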
@@ -707,6 +732,7 @@ class Xtts(BaseTTS):
             repetition_penalty=float(repetition_penalty),
             output_attentions=False,
             output_hidden_states=True,
+            **hf_generate_kwargs,
         )
         last_tokens = []
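With **hf_generate_kwargs now forwarded, extra Hugging Face generate() options passed to the streaming call reach the underlying token generator, matching the non-streaming path. For example (the specific kwarg is illustrative and assumes the generator accepts standard HF sampling options):

chunks = model.inference_stream(
    text,
    "en",
    gpt_cond_latent,
    speaker_embedding,
    typical_p=0.8,  # illustrative HF generate() kwarg, forwarded via **hf_generate_kwargs
)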