diff --git a/TTS/tts/layers/xtts/__init__.py b/TTS/tts/layers/xtts/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/TTS/tts/layers/xtts/stream_generator.py b/TTS/tts/layers/xtts/stream_generator.py
index 77432480..cb098958 100644
--- a/TTS/tts/layers/xtts/stream_generator.py
+++ b/TTS/tts/layers/xtts/stream_generator.py
@@ -4,7 +4,7 @@ import copy
 import inspect
 import random
 import warnings
-from typing import Callable, List, Optional, Union
+from typing import Callable, Optional, Union
 
 import numpy as np
 import torch
@@ -21,10 +21,11 @@ from transformers import (
     PreTrainedModel,
     StoppingCriteriaList,
 )
+from transformers.generation.stopping_criteria import validate_stopping_criteria
 from transformers.generation.utils import GenerateOutput, SampleOutput, logger
 
 
-def setup_seed(seed):
+def setup_seed(seed: int) -> None:
     if seed == -1:
         return
     torch.manual_seed(seed)
@@ -49,9 +50,9 @@ class NewGenerationMixin(GenerationMixin):
         generation_config: Optional[StreamGenerationConfig] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
-        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
         synced_gpus: Optional[bool] = False,
-        seed=0,
+        seed: int = 0,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
         r"""
@@ -90,7 +91,7 @@ class NewGenerationMixin(GenerationMixin):
                 Custom stopping criteria that complement the default stopping criteria built from arguments and a
                 generation config. If a stopping criteria is passed that is already created with the arguments or a
                 generation config an error is thrown. This feature is intended for advanced users.
-            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
+            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*):
                 If provided, this function constraints the beam search to allowed tokens only at each step. If not
                 provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
                 `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
@@ -568,7 +569,7 @@ class NewGenerationMixin(GenerationMixin):
 
                 def typeerror():
                     raise ValueError(
-                        "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
+                        "`force_words_ids` has to either be a `list[list[list[int]]]` or `list[list[int]]`"
                         f"of positive integers, but is {generation_config.force_words_ids}."
                     )
 
@@ -640,7 +641,7 @@ class NewGenerationMixin(GenerationMixin):
         logits_warper: Optional[LogitsProcessorList] = None,
         max_length: Optional[int] = None,
         pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[Union[int, List[int]]] = None,
+        eos_token_id: Optional[Union[int, list[int]]] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_scores: Optional[bool] = None,