refactor(stream_generator): update code for transformers>=4.41.1

In line with
eed9ed6798/src/transformers/generation/utils.py
This commit is contained in:
Enno Hermann 2024-06-15 20:04:23 +02:00
parent 4b6da4e7ba
commit 2a281237d7
1 changed files with 5 additions and 6 deletions

View File

@ -201,16 +201,15 @@ class NewGenerationMixin(GenerationMixin):
# 5. Prepare `input_ids` which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
input_ids = self._prepare_decoder_input_ids_for_generation(
batch_size,
decoder_start_token_id=generation_config.decoder_start_token_id,
bos_token_id=generation_config.bos_token_id,
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
batch_size=batch_size,
model_input_name=model_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=generation_config.decoder_start_token_id,
device=inputs_tensor.device,
)
else:
# if decoder-only then inputs_tensor has to be `input_ids`
input_ids = inputs_tensor
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]