mirror of https://github.com/coqui-ai/TTS.git
Add inference with precomputed latents
This commit is contained in:
parent
a0f657c764
commit
a45cf83b34
|
@ -478,10 +478,10 @@ class Xtts(BaseTTS):
|
|||
"decoder_sampler": config.decoder_sampler,
|
||||
}
|
||||
settings.update(kwargs) # allow overriding of preset settings with kwargs
|
||||
return self.inference(text, ref_audio_path, language, **settings)
|
||||
return self.full_inference(text, ref_audio_path, language, **settings)
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(
|
||||
def full_inference(
|
||||
self,
|
||||
text,
|
||||
ref_audio_path,
|
||||
|
@ -557,6 +557,56 @@ class Xtts(BaseTTS):
|
|||
Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
|
||||
Sample rate is 24kHz.
|
||||
"""
|
||||
(
|
||||
gpt_cond_latent,
|
||||
diffusion_conditioning,
|
||||
speaker_embedding
|
||||
) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
|
||||
return self.inference(
|
||||
text,
|
||||
language,
|
||||
gpt_cond_latent,
|
||||
speaker_embedding,
|
||||
diffusion_conditioning,
|
||||
temperature=temperature,
|
||||
length_penalty=length_penalty,
|
||||
repetition_penalty=repetition_penalty,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
do_sample=do_sample,
|
||||
decoder_iterations=decoder_iterations,
|
||||
cond_free=cond_free,
|
||||
cond_free_k=cond_free_k,
|
||||
diffusion_temperature=diffusion_temperature,
|
||||
decoder_sampler=decoder_sampler,
|
||||
use_hifigan=use_hifigan,
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(
|
||||
self,
|
||||
text,
|
||||
language,
|
||||
gpt_cond_latent,
|
||||
speaker_embedding,
|
||||
diffusion_conditioning,
|
||||
# GPT inference
|
||||
temperature=0.65,
|
||||
length_penalty=1,
|
||||
repetition_penalty=2.0,
|
||||
top_k=50,
|
||||
top_p=0.85,
|
||||
do_sample=True,
|
||||
# Decoder inference
|
||||
decoder_iterations=100,
|
||||
cond_free=True,
|
||||
cond_free_k=2,
|
||||
diffusion_temperature=1.0,
|
||||
decoder_sampler="ddim",
|
||||
use_hifigan=True,
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
text = f"[{language}]{text.strip().lower()}"
|
||||
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
|
||||
|
||||
|
@ -564,12 +614,6 @@ class Xtts(BaseTTS):
|
|||
text_tokens.shape[-1] < self.args.gpt_max_text_tokens
|
||||
), " ❗ XTTS can only generate text with a maximum of 400 tokens."
|
||||
|
||||
(
|
||||
gpt_cond_latent,
|
||||
diffusion_conditioning,
|
||||
speaker_embedding
|
||||
) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
|
||||
|
||||
if not use_hifigan:
|
||||
diffuser = load_discrete_vocoder_diffuser(
|
||||
desired_diffusion_steps=decoder_iterations,
|
||||
|
@ -636,18 +680,6 @@ class Xtts(BaseTTS):
|
|||
|
||||
return {"wav": wav.cpu().numpy().squeeze()}
|
||||
|
||||
def inference_speaker_cond(self, ref_audio_path, gpt_cond_len=3):
|
||||
(
|
||||
gpt_cond_latent,
|
||||
diffusion_conditioning,
|
||||
speaker_embedding
|
||||
) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=3)
|
||||
return {
|
||||
"gpt_cond_latent": gpt_cond_latent,
|
||||
"speaker_embedding": speaker_embedding,
|
||||
"diffusion_conditioning": diffusion_conditioning,
|
||||
}
|
||||
|
||||
def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
|
||||
"""Handle chunk formatting in streaming mode"""
|
||||
wav_chunk = wav_gen[:-overlap_len]
|
||||
|
@ -668,7 +700,6 @@ class Xtts(BaseTTS):
|
|||
language,
|
||||
gpt_cond_latent,
|
||||
speaker_embedding,
|
||||
diffusion_conditioning,
|
||||
# Streaming
|
||||
stream_chunk_size=20,
|
||||
overlap_wav_len=1024,
|
||||
|
@ -678,14 +709,8 @@ class Xtts(BaseTTS):
|
|||
repetition_penalty=2.0,
|
||||
top_k=50,
|
||||
top_p=0.85,
|
||||
gpt_cond_len=4,
|
||||
do_sample=True,
|
||||
# Decoder inference
|
||||
decoder_iterations=100,
|
||||
cond_free=True,
|
||||
cond_free_k=2,
|
||||
diffusion_temperature=1.0,
|
||||
decoder_sampler="ddim",
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
text = f"[{language}]{text.strip().lower()}"
|
||||
|
@ -707,6 +732,7 @@ class Xtts(BaseTTS):
|
|||
repetition_penalty=float(repetition_penalty),
|
||||
output_attentions=False,
|
||||
output_hidden_states=True,
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
|
||||
last_tokens = []
|
||||
|
|
Loading…
Reference in New Issue