Remove unused kwarg and add num_beams=1 as default

Edresson Casanova 2023-11-06 10:53:25 -03:00
parent a1c441f205
commit 9e92adc5ac
2 changed files with 3 additions and 1 deletion

@@ -562,7 +562,7 @@ class GPT(nn.Module):
 
     def inference(self, cond_latents, text_inputs, **hf_generate_kwargs):
         self.compute_embeddings(cond_latents, text_inputs)
-        return self.generate(cond_latents, text_inputs, input_tokens=None, **hf_generate_kwargs)
+        return self.generate(cond_latents, text_inputs, **hf_generate_kwargs)
 
     def compute_embeddings(
         self,
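
For context: inference() forwards everything in **hf_generate_kwargs straight into Hugging Face's generate(). The commit only says the kwarg was unused; one plausible motivation (an assumption, not stated in the diff) is that recent transformers releases validate generate() kwargs and reject any the model cannot consume, so dead kwargs are worth pruning. A minimal sketch against plain GPT-2, not this repo's code:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("hello", return_tensors="pt")

try:
    # "input_tokens" is not an argument GPT-2's forward() accepts,
    # so the kwarg validator flags it.
    model.generate(**inputs, input_tokens=inputs["input_ids"], max_new_tokens=5)
except ValueError as err:
    print(err)  # the error message names the unused kwarg

Note that the validator skips None-valued kwargs, so the old input_tokens=None was silently ignored rather than an error; removing it is cleanup either way.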

@@ -642,6 +642,7 @@ class Xtts(BaseTTS):
         diffusion_temperature=1.0,
         decoder_sampler="ddim",
         decoder="hifigan",
+        num_beams=1,
         **hf_generate_kwargs,
     ):
         text = text.strip().lower()
@@ -673,6 +674,7 @@ class Xtts(BaseTTS):
             top_k=top_k,
             temperature=temperature,
             num_return_sequences=self.gpt_batch_size,
+            num_beams=num_beams,
             length_penalty=length_penalty,
             repetition_penalty=repetition_penalty,
             output_attentions=False,
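
The new num_beams parameter travels the same **hf_generate_kwargs path down into Hugging Face's generate(). There, num_beams=1 (the new default) means no beam search, i.e. plain greedy or sampled decoding and hence the same behavior as before this commit, while values above 1 switch decoding to beam search. A minimal sketch with plain GPT-2, illustrating the transformers semantics rather than the XTTS call path:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("hello", return_tensors="pt")

# num_beams=1: no beam search; with do_sample=True this is plain sampling.
sampled = model.generate(**inputs, num_beams=1, do_sample=True, max_new_tokens=8)

# num_beams=4: beam search keeps the 4 best-scoring partial sequences per
# step and returns the highest-scoring finished one.
beamed = model.generate(**inputs, num_beams=4, max_new_tokens=8)

One caveat for callers: Hugging Face requires num_return_sequences <= num_beams under beam search, and the call above also passes num_return_sequences=self.gpt_batch_size, so num_beams > 1 only works when it is at least that large.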