Fix for latest TF

Daniel Walmsley 2024-07-08 14:40:35 -07:00
parent 6696abfa52
commit bf9a38fabd
1 changed file with 3 additions and 3 deletions


@@ -430,7 +430,7 @@ class NewGenerationMixin(GenerationMixin):
elif is_sample_gen_mode:
# 11. prepare logits warper
- logits_warper = self._get_logits_warper(generation_config)
+ logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device)
# 12. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -455,7 +455,7 @@ class NewGenerationMixin(GenerationMixin):
)
elif is_sample_gen_stream_mode:
# 11. prepare logits warper
- logits_warper = self._get_logits_warper(generation_config)
+ logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device)
# 12. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -517,7 +517,7 @@ class NewGenerationMixin(GenerationMixin):
elif is_beam_sample_gen_mode:
# 11. prepare logits warper
- logits_warper = self._get_logits_warper(generation_config)
+ logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device)
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
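
All three hunks make the same change: newer releases of the transformers library expect `_get_logits_warper` to receive a device argument, so each call site now passes `inputs_tensor.device`. If the surrounding code needed to run against both the old single-argument signature and the new one, a small guard like the sketch below could be used. This is not part of the commit; the helper name `_prepare_logits_warper` is hypothetical, and only the two call forms shown in the diff are taken from the source.

import inspect

def _prepare_logits_warper(self, generation_config, inputs_tensor):
    # Sketch only: dispatch on the installed transformers signature.
    params = inspect.signature(self._get_logits_warper).parameters
    if "device" in params:
        # Newer signature, as used in this commit.
        return self._get_logits_warper(generation_config, inputs_tensor.device)
    # Older signature: generation_config only.
    return self._get_logits_warper(generation_config)

With a helper like this, the three call sites in the diff would collapse to a single `self._prepare_logits_warper(generation_config, inputs_tensor)` call, at the cost of an extra indirection that the commit itself does not introduce.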