patch stream bug

xinsen 2024-08-06 16:17:33 +08:00
parent fa28f99f15
commit 17b83e2f4e
1 changed file with 8 additions and 3 deletions


@@ -183,10 +183,14 @@ class NewGenerationMixin(GenerationMixin):
         requires_attention_mask = "encoder_outputs" not in model_kwargs
         if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
+            pad_token_tensor = torch.tensor([generation_config.pad_token_id],
+                                            device=inputs_tensor.device) if generation_config.pad_token_id is not None else None
+            eos_token_tensor = torch.tensor([generation_config.eos_token_id],
+                                            device=inputs_tensor.device) if generation_config.eos_token_id is not None else None
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                 inputs_tensor,
-                generation_config.pad_token_id,
-                generation_config.eos_token_id,
+                pad_token_tensor,
+                eos_token_tensor,
             )
         # decoder-only models should use left-padding for generation
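Note (not part of the diff): the hunk above appears to adapt the call to `_prepare_attention_mask_for_generation`, which in recent transformers releases expects the pad/eos token ids as tensors on the input device rather than plain ints. A minimal, self-contained sketch of the wrapping pattern the patch introduces; the helper name `token_id_to_tensor` and the example values are illustrative, not part of the repository:

import torch

def token_id_to_tensor(token_id, device):
    # Wrap an optional int token id into a 1-element tensor on the target device,
    # mirroring the pad_token_tensor / eos_token_tensor construction in the hunk above.
    return torch.tensor([token_id], device=device) if token_id is not None else None

inputs_tensor = torch.zeros((1, 4), dtype=torch.long)               # stand-in for the real model inputs
pad_token_tensor = token_id_to_tensor(0, inputs_tensor.device)       # tensor([0])
eos_token_tensor = token_id_to_tensor(None, inputs_tensor.device)    # None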
@@ -409,7 +413,8 @@ class NewGenerationMixin(GenerationMixin):
             )
         elif is_sample_gen_stream_mode:
             # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config)
+            # logits_warper = self._get_logits_warper(generation_config)
+            logits_warper = self._get_logits_warper(generation_config, device=inputs_tensor.device)
             # 12. expand input_ids with `num_return_sequences` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
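Note (not part of the diff): the second hunk looks like the same kind of API adaptation; newer transformers versions added a `device` parameter to the private `_get_logits_warper` helper, which the patched call now forwards. A version-tolerant sketch, assuming a loaded `model` exposing that private helper (private API, subject to change; `get_logits_warper_compat` is a hypothetical name):

import inspect

def get_logits_warper_compat(model, generation_config, device):
    # Pass `device` only if the installed transformers version accepts it.
    params = inspect.signature(model._get_logits_warper).parameters
    if "device" in params:
        return model._get_logits_warper(generation_config, device=device)
    return model._get_logits_warper(generation_config)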