Merge pull request #46 from idiap/fix-xtts-streaming

Fix XTTS streaming for transformers update
Enno Hermann 2024-06-18 14:54:15 +01:00 committed by GitHub
commit 98c0f86cb3
3 changed files with 19 additions and 27 deletions
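
Summary: this change adapts the XTTS streaming generator (`NewGenerationMixin`) to the generation-API changes in transformers 4.41. The hand-rolled pad-token fallback is dropped in favour of the upstream `_prepare_special_tokens` helper, `_prepare_decoder_input_ids_for_generation` is called with its new keyword signature and `(input_ids, model_kwargs)` tuple return, the decoder-only branch now respects `model_input_name`, `validate_stopping_criteria` is imported from `transformers.generation.stopping_criteria`, and `typing.List` annotations become the built-in `list`. The dependency pin moves from `transformers>=4.33.0,<4.41.0` to `transformers>=4.41.1`.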

@@ -4,7 +4,7 @@ import copy
 import inspect
 import random
 import warnings
-from typing import Callable, List, Optional, Union
+from typing import Callable, Optional, Union

 import numpy as np
 import torch
@@ -21,10 +21,11 @@ from transformers import (
     PreTrainedModel,
     StoppingCriteriaList,
 )
+from transformers.generation.stopping_criteria import validate_stopping_criteria
 from transformers.generation.utils import GenerateOutput, SampleOutput, logger


-def setup_seed(seed):
+def setup_seed(seed: int) -> None:
     if seed == -1:
         return
     torch.manual_seed(seed)
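
Aside: the hunk window cuts off at `torch.manual_seed`; the helper presumably goes on to seed the other RNGs. A minimal sketch of such a helper; everything past `torch.manual_seed` is an assumption, not part of this diff:

import random

import numpy as np
import torch


def setup_seed(seed: int) -> None:
    # -1 is the sentinel for "leave RNG state alone", as in the hunk above
    if seed == -1:
        return
    torch.manual_seed(seed)
    # assumed follow-up seeding, not shown in the hunk
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)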
@@ -49,9 +50,9 @@ class NewGenerationMixin(GenerationMixin):
         generation_config: Optional[StreamGenerationConfig] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
-        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
         synced_gpus: Optional[bool] = False,
-        seed=0,
+        seed: int = 0,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
         r"""
@@ -90,7 +91,7 @@ class NewGenerationMixin(GenerationMixin):
                 Custom stopping criteria that complement the default stopping criteria built from arguments and a
                 generation config. If a stopping criteria is passed that is already created with the arguments or a
                 generation config an error is thrown. This feature is intended for advanced users.
-            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
+            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*):
                 If provided, this function constraints the beam search to allowed tokens only at each step. If not
                 provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
                 `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
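
Aside: as the docstring above describes, the callable receives the batch ID and the current `input_ids` and returns the token ids allowed at the next step. A toy constraint function (the ids and vocabulary size are invented for illustration):

import torch


def allow_small_ids(batch_id: int, input_ids: torch.Tensor) -> list[int]:
    # once more than 5 tokens are present, restrict sampling to ids below 1000
    if input_ids.shape[-1] > 5:
        return list(range(1000))
    return list(range(32000))  # assumed vocabulary size, purely illustrative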
@@ -151,18 +152,7 @@ class NewGenerationMixin(GenerationMixin):
         # 2. Set generation parameters if not already defined
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
-            if model_kwargs.get("attention_mask", None) is None:
-                logger.warning(
-                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
-                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
-                )
-            eos_token_id = generation_config.eos_token_id
-            if isinstance(eos_token_id, list):
-                eos_token_id = eos_token_id[0]
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
-            generation_config.pad_token_id = eos_token_id
+        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

         # 3. Define model inputs
         # inputs_tensor has to be defined
@ -174,6 +164,9 @@ class NewGenerationMixin(GenerationMixin):
) )
batch_size = inputs_tensor.shape[0] batch_size = inputs_tensor.shape[0]
device = inputs_tensor.device
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
# 4. Define other model kwargs # 4. Define other model kwargs
model_kwargs["output_attentions"] = generation_config.output_attentions model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
@@ -182,7 +175,7 @@ class NewGenerationMixin(GenerationMixin):
         accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
         requires_attention_mask = "encoder_outputs" not in model_kwargs

-        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
+        if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                 inputs_tensor,
                 generation_config.pad_token_id,
@@ -209,16 +202,15 @@ class NewGenerationMixin(GenerationMixin):
         # 5. Prepare `input_ids` which will be used for auto-regressive generation
         if self.config.is_encoder_decoder:
-            input_ids = self._prepare_decoder_input_ids_for_generation(
-                batch_size,
-                decoder_start_token_id=generation_config.decoder_start_token_id,
-                bos_token_id=generation_config.bos_token_id,
+            input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
+                batch_size=batch_size,
+                model_input_name=model_input_name,
                 model_kwargs=model_kwargs,
+                decoder_start_token_id=generation_config.decoder_start_token_id,
                 device=inputs_tensor.device,
             )
         else:
-            # if decoder-only then inputs_tensor has to be `input_ids`
-            input_ids = inputs_tensor
+            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

         # 6. Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = input_ids.shape[-1]
@@ -577,7 +569,7 @@ class NewGenerationMixin(GenerationMixin):
         def typeerror():
             raise ValueError(
-                "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
+                "`force_words_ids` has to either be a `list[list[list[int]]]` or `list[list[int]]`"
                 f"of positive integers, but is {generation_config.force_words_ids}."
             )
@@ -649,7 +641,7 @@ class NewGenerationMixin(GenerationMixin):
         logits_warper: Optional[LogitsProcessorList] = None,
         max_length: Optional[int] = None,
         pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[Union[int, List[int]]] = None,
+        eos_token_id: Optional[Union[int, list[int]]] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_scores: Optional[bool] = None,
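
The heart of the fix sits in the hunks around steps 2 and 3 above: the deleted pad-token fallback is superseded by transformers' own `_prepare_special_tokens`, fed by the new `kwargs_has_attention_mask` flag. A standalone sketch of the migrated call shape, assuming transformers>=4.41 and using GPT-2 as a stand-in for the XTTS GPT model (`_prepare_special_tokens` is a private helper; this only illustrates the pattern, it is not code from this PR):

from transformers import AutoModelForCausalLM, AutoTokenizer

# GPT-2 stands in for the XTTS GPT model; downloading it needs network access
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("hello there", return_tensors="pt")

model_kwargs = {"attention_mask": inputs["attention_mask"]}
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

# one call now replaces the deleted pad_token_id/eos_token_id fallback block
model._prepare_special_tokens(
    model.generation_config,
    kwargs_has_attention_mask,
    device=inputs["input_ids"].device,
)

The remaining hunk raises the dependency floor to match.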


@@ -69,7 +69,7 @@ dependencies = [
     "gruut[de,es,fr]==2.2.3",
     # Tortoise
     "einops>=0.6.0",
-    "transformers>=4.33.0,<4.41.0",
+    "transformers>=4.41.1",
     # Bark
     "encodec>=0.1.1",
     # XTTS
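
With the floor at 4.41.1, a generic runtime guard can surface stale environments early (not part of this PR; `packaging` is assumed available, since pip ships with it):

from packaging.version import Version

import transformers

if Version(transformers.__version__) < Version("4.41.1"):
    raise RuntimeError(
        f"transformers {transformers.__version__} is too old for XTTS streaming; "
        "install >=4.41.1"
    )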