mirror of https://github.com/coqui-ai/TTS.git
patch stream bug
parent fa28f99f15
commit 17b83e2f4e
@@ -183,10 +183,14 @@ class NewGenerationMixin(GenerationMixin):
         requires_attention_mask = "encoder_outputs" not in model_kwargs
 
         if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
+            pad_token_tensor = torch.tensor([generation_config.pad_token_id],
+                                            device=inputs_tensor.device) if generation_config.pad_token_id is not None else None
+            eos_token_tensor = torch.tensor([generation_config.eos_token_id],
+                                            device=inputs_tensor.device) if generation_config.eos_token_id is not None else None
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                 inputs_tensor,
-                generation_config.pad_token_id,
-                generation_config.eos_token_id,
+                pad_token_tensor,
+                eos_token_tensor,
             )
 
         # decoder-only models should use left-padding for generation
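Background on this hunk: newer transformers releases changed the private helper _prepare_attention_mask_for_generation to take the pad/eos tokens as tensors on the input's device rather than as plain ints, which is why the patch wraps generation_config.pad_token_id and generation_config.eos_token_id in one-element tensors before the call. A minimal standalone sketch of that conversion, with hypothetical token ids (the signature is private and version-dependent, so treat this as illustration only):

import torch

def to_token_tensor(token_id, device):
    # Wrap an optional int token id in a one-element tensor on the target
    # device; preserve None so downstream "is not None" checks still work.
    return torch.tensor([token_id], device=device) if token_id is not None else None

inputs_tensor = torch.ones(1, 8, dtype=torch.long)           # stand-in for tokenized input
pad_token_tensor = to_token_tensor(0, inputs_tensor.device)  # hypothetical pad_token_id = 0
eos_token_tensor = to_token_tensor(2, inputs_tensor.device)  # hypothetical eos_token_id = 2
print(pad_token_tensor, eos_token_tensor)                    # tensor([0]) tensor([2])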
@@ -409,7 +413,8 @@ class NewGenerationMixin(GenerationMixin):
             )
         elif is_sample_gen_stream_mode:
             # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config)
+            # logits_warper = self._get_logits_warper(generation_config)
+            logits_warper = self._get_logits_warper(generation_config, device=inputs_tensor.device)
 
             # 12. expand input_ids with `num_return_sequences` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
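Background on this hunk: newer transformers releases added a device argument to _get_logits_warper, so the old one-argument call fails there, while the keyword call fails on older releases that lack the parameter. A hedged sketch of a version-tolerant wrapper, assuming self is a GenerationMixin subclass such as NewGenerationMixin (the helper name get_logits_warper_compat is ours, not part of the repo):

def get_logits_warper_compat(self, generation_config, device):
    # Hypothetical compatibility helper, not part of the patched file.
    # Try the newer signature first, then fall back to the older one.
    try:
        return self._get_logits_warper(generation_config, device=device)
    except TypeError:
        return self._get_logits_warper(generation_config)

# Usage inside the streaming branch would then be something like:
#     logits_warper = get_logits_warper_compat(self, generation_config, inputs_tensor.device)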