chore(stream_generator): address lint issues

This commit is contained in:
Enno Hermann 2024-06-15 20:33:46 +02:00
parent 2a281237d7
commit 4d9e18ea7d
2 changed files with 8 additions and 7 deletions

View File

View File

@@ -4,7 +4,7 @@ import copy
import inspect import inspect
import random import random
import warnings import warnings
from typing import Callable, List, Optional, Union from typing import Callable, Optional, Union
import numpy as np import numpy as np
import torch import torch
@@ -21,10 +21,11 @@ from transformers import (
PreTrainedModel, PreTrainedModel,
StoppingCriteriaList, StoppingCriteriaList,
) )
from transformers.generation.stopping_criteria import validate_stopping_criteria
from transformers.generation.utils import GenerateOutput, SampleOutput, logger from transformers.generation.utils import GenerateOutput, SampleOutput, logger
def setup_seed(seed): def setup_seed(seed: int) -> None:
if seed == -1: if seed == -1:
return return
torch.manual_seed(seed) torch.manual_seed(seed)
@@ -49,9 +50,9 @@ class NewGenerationMixin(GenerationMixin):
generation_config: Optional[StreamGenerationConfig] = None, generation_config: Optional[StreamGenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None, logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
synced_gpus: Optional[bool] = False, synced_gpus: Optional[bool] = False,
seed=0, seed: int = 0,
**kwargs, **kwargs,
) -> Union[GenerateOutput, torch.LongTensor]: ) -> Union[GenerateOutput, torch.LongTensor]:
r""" r"""
@@ -90,7 +91,7 @@ class NewGenerationMixin(GenerationMixin):
Custom stopping criteria that complement the default stopping criteria built from arguments and a Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config. If a stopping criteria is passed that is already created with the arguments or a generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users. generation config an error is thrown. This feature is intended for advanced users.
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*):
If provided, this function constraints the beam search to allowed tokens only at each step. If not If provided, this function constraints the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
`input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
@@ -568,7 +569,7 @@ class NewGenerationMixin(GenerationMixin):
def typeerror(): def typeerror():
raise ValueError( raise ValueError(
"`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`" "`force_words_ids` has to either be a `list[list[list[int]]]` or `list[list[int]]`"
f"of positive integers, but is {generation_config.force_words_ids}." f"of positive integers, but is {generation_config.force_words_ids}."
) )
@@ -640,7 +641,7 @@ class NewGenerationMixin(GenerationMixin):
logits_warper: Optional[LogitsProcessorList] = None, logits_warper: Optional[LogitsProcessorList] = None,
max_length: Optional[int] = None, max_length: Optional[int] = None,
pad_token_id: Optional[int] = None, pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None, eos_token_id: Optional[Union[int, list[int]]] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None, output_scores: Optional[bool] = None,