mirror of https://github.com/coqui-ai/TTS.git
chore(stream_generator): address lint issues
This commit is contained in:
parent
2a281237d7
commit
4d9e18ea7d
|
@ -4,7 +4,7 @@ import copy
|
|||
import inspect
|
||||
import random
|
||||
import warnings
|
||||
from typing import Callable, List, Optional, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -21,10 +21,11 @@ from transformers import (
|
|||
PreTrainedModel,
|
||||
StoppingCriteriaList,
|
||||
)
|
||||
from transformers.generation.stopping_criteria import validate_stopping_criteria
|
||||
from transformers.generation.utils import GenerateOutput, SampleOutput, logger
|
||||
|
||||
|
||||
def setup_seed(seed):
|
||||
def setup_seed(seed: int) -> None:
|
||||
if seed == -1:
|
||||
return
|
||||
torch.manual_seed(seed)
|
||||
|
@ -49,9 +50,9 @@ class NewGenerationMixin(GenerationMixin):
|
|||
generation_config: Optional[StreamGenerationConfig] = None,
|
||||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
|
||||
synced_gpus: Optional[bool] = False,
|
||||
seed=0,
|
||||
seed: int = 0,
|
||||
**kwargs,
|
||||
) -> Union[GenerateOutput, torch.LongTensor]:
|
||||
r"""
|
||||
|
@ -90,7 +91,7 @@ class NewGenerationMixin(GenerationMixin):
|
|||
Custom stopping criteria that complement the default stopping criteria built from arguments and a
|
||||
generation config. If a stopping criteria is passed that is already created with the arguments or a
|
||||
generation config an error is thrown. This feature is intended for advanced users.
|
||||
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
|
||||
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*):
|
||||
If provided, this function constraints the beam search to allowed tokens only at each step. If not
|
||||
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
|
||||
`input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
|
||||
|
@ -568,7 +569,7 @@ class NewGenerationMixin(GenerationMixin):
|
|||
|
||||
def typeerror():
|
||||
raise ValueError(
|
||||
"`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
|
||||
"`force_words_ids` has to either be a `list[list[list[int]]]` or `list[list[int]]`"
|
||||
f"of positive integers, but is {generation_config.force_words_ids}."
|
||||
)
|
||||
|
||||
|
@ -640,7 +641,7 @@ class NewGenerationMixin(GenerationMixin):
|
|||
logits_warper: Optional[LogitsProcessorList] = None,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||
eos_token_id: Optional[Union[int, list[int]]] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_scores: Optional[bool] = None,
|
||||
|
|
Loading…
Reference in New Issue