chore(stream_generator): address lint issues

This commit is contained in:
Enno Hermann 2024-06-15 20:33:46 +02:00
parent 2a281237d7
commit 4d9e18ea7d
2 changed files with 8 additions and 7 deletions

View File

View File

@@ -4,7 +4,7 @@ import copy
import inspect
import random
import warnings
from typing import Callable, List, Optional, Union
from typing import Callable, Optional, Union
import numpy as np
import torch
@@ -21,10 +21,11 @@ from transformers import (
PreTrainedModel,
StoppingCriteriaList,
)
from transformers.generation.stopping_criteria import validate_stopping_criteria
from transformers.generation.utils import GenerateOutput, SampleOutput, logger
def setup_seed(seed):
def setup_seed(seed: int) -> None:
if seed == -1:
return
torch.manual_seed(seed)
@@ -49,9 +50,9 @@ class NewGenerationMixin(GenerationMixin):
generation_config: Optional[StreamGenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
synced_gpus: Optional[bool] = False,
seed=0,
seed: int = 0,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
r"""
@@ -90,7 +91,7 @@ class NewGenerationMixin(GenerationMixin):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*):
If provided, this function constraints the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
`input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
@@ -568,7 +569,7 @@ class NewGenerationMixin(GenerationMixin):
def typeerror():
raise ValueError(
"`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
"`force_words_ids` has to either be a `list[list[list[int]]]` or `list[list[int]]`"
f"of positive integers, but is {generation_config.force_words_ids}."
)
@@ -640,7 +641,7 @@ class NewGenerationMixin(GenerationMixin):
logits_warper: Optional[LogitsProcessorList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
eos_token_id: Optional[Union[int, list[int]]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,