mirror of https://github.com/coqui-ai/TTS.git
Add inference with precomputed latents
commit a45cf83b34 (parent a0f657c764)

@@ -478,10 +478,10 @@ class Xtts(BaseTTS):
             "decoder_sampler": config.decoder_sampler,
         }
         settings.update(kwargs)  # allow overriding of preset settings with kwargs
-        return self.inference(text, ref_audio_path, language, **settings)
+        return self.full_inference(text, ref_audio_path, language, **settings)
 
     @torch.no_grad()
-    def inference(
+    def full_inference(
         self,
         text,
         ref_audio_path,
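
For orientation, the one-shot entry point keeps its behavior and only changes name. A minimal usage sketch (not part of the diff), assuming `model` is an already-loaded Xtts instance and `reference.wav` is a speaker clip:

# Hypothetical one-shot call: full_inference() still computes the speaker
# latents from the reference clip internally on every invocation.
out = model.full_inference("Hello world.", "reference.wav", "en")
wav = out["wav"]  # 24 kHz audio as a numpy array
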
@@ -557,6 +557,56 @@ class Xtts(BaseTTS):
             Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
             Sample rate is 24kHz.
         """
+        (
+            gpt_cond_latent,
+            diffusion_conditioning,
+            speaker_embedding
+        ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
+        return self.inference(
+            text,
+            language,
+            gpt_cond_latent,
+            speaker_embedding,
+            diffusion_conditioning,
+            temperature=temperature,
+            length_penalty=length_penalty,
+            repetition_penalty=repetition_penalty,
+            top_k=top_k,
+            top_p=top_p,
+            do_sample=do_sample,
+            decoder_iterations=decoder_iterations,
+            cond_free=cond_free,
+            cond_free_k=cond_free_k,
+            diffusion_temperature=diffusion_temperature,
+            decoder_sampler=decoder_sampler,
+            use_hifigan=use_hifigan,
+            **hf_generate_kwargs,
+        )
+
+    @torch.no_grad()
+    def inference(
+        self,
+        text,
+        language,
+        gpt_cond_latent,
+        speaker_embedding,
+        diffusion_conditioning,
+        # GPT inference
+        temperature=0.65,
+        length_penalty=1,
+        repetition_penalty=2.0,
+        top_k=50,
+        top_p=0.85,
+        do_sample=True,
+        # Decoder inference
+        decoder_iterations=100,
+        cond_free=True,
+        cond_free_k=2,
+        diffusion_temperature=1.0,
+        decoder_sampler="ddim",
+        use_hifigan=True,
+        **hf_generate_kwargs,
+    ):
         text = f"[{language}]{text.strip().lower()}"
         text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
 
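
The point of the split above: `get_conditioning_latents` can now be run once per speaker and its outputs fed to `inference` repeatedly, avoiding re-encoding the reference audio on every call. A sketch under the same assumptions (`model` is a loaded Xtts):

# Compute the three conditioning tensors once; the return order matches the
# tuple unpacking in the diff above.
(
    gpt_cond_latent,
    diffusion_conditioning,
    speaker_embedding,
) = model.get_conditioning_latents(audio_path="reference.wav", gpt_cond_len=3)

# Reuse them across many utterances without touching the reference clip again.
for sentence in ["First sentence.", "Second sentence."]:
    out = model.inference(
        sentence,
        "en",
        gpt_cond_latent,
        speaker_embedding,
        diffusion_conditioning,
    )
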
@@ -564,12 +614,6 @@ class Xtts(BaseTTS):
             text_tokens.shape[-1] < self.args.gpt_max_text_tokens
         ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
 
-        (
-            gpt_cond_latent,
-            diffusion_conditioning,
-            speaker_embedding
-        ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
-
         if not use_hifigan:
             diffuser = load_discrete_vocoder_diffuser(
                 desired_diffusion_steps=decoder_iterations,
@@ -636,18 +680,6 @@ class Xtts(BaseTTS):
 
         return {"wav": wav.cpu().numpy().squeeze()}
 
-    def inference_speaker_cond(self, ref_audio_path, gpt_cond_len=3):
-        (
-            gpt_cond_latent,
-            diffusion_conditioning,
-            speaker_embedding
-        ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=3)
-        return {
-            "gpt_cond_latent": gpt_cond_latent,
-            "speaker_embedding": speaker_embedding,
-            "diffusion_conditioning": diffusion_conditioning,
-        }
-
     def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
         """Handle chunk formatting in streaming mode"""
         wav_chunk = wav_gen[:-overlap_len]
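
`inference_speaker_cond` is dropped rather than renamed; callers that relied on its dict output can rebuild it from `get_conditioning_latents` directly. (Incidentally, the removed helper hardcoded `gpt_cond_len=3` and ignored its own parameter.) A sketch, again assuming a loaded `model`:

# Equivalent of the removed helper: get_conditioning_latents() returns
# (gpt_cond_latent, diffusion_conditioning, speaker_embedding) in that order.
gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(
    audio_path="reference.wav", gpt_cond_len=3
)
latents = {
    "gpt_cond_latent": gpt_cond_latent,
    "speaker_embedding": speaker_embedding,
    "diffusion_conditioning": diffusion_conditioning,
}
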
@@ -668,7 +700,6 @@ class Xtts(BaseTTS):
         language,
         gpt_cond_latent,
         speaker_embedding,
-        diffusion_conditioning,
         # Streaming
         stream_chunk_size=20,
         overlap_wav_len=1024,
@@ -678,14 +709,8 @@ class Xtts(BaseTTS):
         repetition_penalty=2.0,
         top_k=50,
         top_p=0.85,
-        gpt_cond_len=4,
         do_sample=True,
         # Decoder inference
-        decoder_iterations=100,
-        cond_free=True,
-        cond_free_k=2,
-        diffusion_temperature=1.0,
-        decoder_sampler="ddim",
         **hf_generate_kwargs,
     ):
         text = f"[{language}]{text.strip().lower()}"
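
After the two streaming hunks above, `inference_stream` takes only the GPT conditioning latent and speaker embedding; the diffusion tensor and the decoder knobs are gone, consistent with the streaming path decoding through HiFi-GAN rather than the diffusion decoder. A hedged sketch of a streaming call, reusing the latents computed earlier:

# Streaming sketch: inference_stream() yields audio chunks as torch tensors
# while generation is still in progress. diffusion_conditioning is no longer
# accepted here after this commit.
for i, chunk in enumerate(
    model.inference_stream(
        "A longer passage to synthesize incrementally.",
        "en",
        gpt_cond_latent,
        speaker_embedding,
        stream_chunk_size=20,
    )
):
    print(f"chunk {i}: {chunk.shape[-1]} samples")
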
@@ -707,6 +732,7 @@ class Xtts(BaseTTS):
             repetition_penalty=float(repetition_penalty),
             output_attentions=False,
             output_hidden_states=True,
+            **hf_generate_kwargs,
         )
 
         last_tokens = []
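
The last hunk forwards `**hf_generate_kwargs` into the underlying generator call, so generation options that are not exposed as named parameters can be threaded through from the streaming entry point. A sketch; whether a given kwarg is honored depends on the underlying HuggingFace-style generate implementation:

# Extra keyword arguments now pass straight through to the GPT generator.
stream = model.inference_stream(
    "Text to stream.",
    "en",
    gpt_cond_latent,
    speaker_embedding,
    num_beams=1,  # assumed example of a generate() kwarg, not from the diff
)
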