mirror of https://github.com/coqui-ai/TTS.git
Merge pull request #46 from idiap/fix-xtts-streaming
Fix XTTS streaming for transformers update
commit 98c0f86cb3
@@ -4,7 +4,7 @@ import copy
 import inspect
 import random
 import warnings
-from typing import Callable, List, Optional, Union
+from typing import Callable, Optional, Union

 import numpy as np
 import torch
@@ -21,10 +21,11 @@ from transformers import (
     PreTrainedModel,
     StoppingCriteriaList,
 )
+from transformers.generation.stopping_criteria import validate_stopping_criteria
 from transformers.generation.utils import GenerateOutput, SampleOutput, logger


-def setup_seed(seed):
+def setup_seed(seed: int) -> None:
     if seed == -1:
         return
     torch.manual_seed(seed)
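
As a side note on the hunk above: a minimal, self-contained restatement of the `setup_seed` contract, assuming only torch is installed (the full helper in this file most likely also seeds numpy/random/CUDA, which this sketch omits):

import torch

def setup_seed(seed: int) -> None:
    # -1 is the sentinel for "leave the RNG state untouched".
    if seed == -1:
        return
    torch.manual_seed(seed)

setup_seed(0)   # deterministic sampling across runs
setup_seed(-1)  # no-op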
@@ -49,9 +50,9 @@ class NewGenerationMixin(GenerationMixin):
         generation_config: Optional[StreamGenerationConfig] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
-        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
         synced_gpus: Optional[bool] = False,
-        seed=0,
+        seed: int = 0,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
         r"""
@@ -90,7 +91,7 @@ class NewGenerationMixin(GenerationMixin):
                 Custom stopping criteria that complement the default stopping criteria built from arguments and a
                 generation config. If a stopping criteria is passed that is already created with the arguments or a
                 generation config an error is thrown. This feature is intended for advanced users.
-            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
+            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*):
                 If provided, this function constraints the beam search to allowed tokens only at each step. If not
                 provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
                 `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
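
The docstring above describes the `prefix_allowed_tokens_fn` hook; a small, illustrative example matching the updated `Callable[[int, torch.Tensor], list[int]]` annotation (the token ids are made up; a real constraint would be derived from the tokenizer):

import torch

ALLOWED_TOKEN_IDS = [5, 7, 42]  # hypothetical vocabulary subset

def only_allowed_tokens(batch_id: int, input_ids: torch.Tensor) -> list[int]:
    # Called once per step and per batch item; returns the ids permitted next.
    return ALLOWED_TOKEN_IDS

print(only_allowed_tokens(0, torch.tensor([1, 2, 3])))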
@@ -151,18 +152,7 @@ class NewGenerationMixin(GenerationMixin):
         # 2. Set generation parameters if not already defined
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
-            if model_kwargs.get("attention_mask", None) is None:
-                logger.warning(
-                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
-                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
-                )
-            eos_token_id = generation_config.eos_token_id
-            if isinstance(eos_token_id, list):
-                eos_token_id = eos_token_id[0]
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
-            generation_config.pad_token_id = eos_token_id
+        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

         # 3. Define model inputs
         # inputs_tensor has to be defined
@@ -174,6 +164,9 @@ class NewGenerationMixin(GenerationMixin):
         )
         batch_size = inputs_tensor.shape[0]

+        device = inputs_tensor.device
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
+
         # 4. Define other model kwargs
         model_kwargs["output_attentions"] = generation_config.output_attentions
         model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
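
The `_prepare_special_tokens` call added above is the key compatibility point: the helper only exists on `GenerationMixin` in newer transformers releases (hence the dependency bump at the end of this diff), and it takes over the pad/eos bookkeeping that the removed block handled manually. A quick check of whether the installed version provides it, assuming transformers is importable:

from transformers import GenerationMixin

# Expected to print True on transformers >= 4.41, False on older releases.
print(hasattr(GenerationMixin, "_prepare_special_tokens"))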
@@ -182,7 +175,7 @@ class NewGenerationMixin(GenerationMixin):
         accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
         requires_attention_mask = "encoder_outputs" not in model_kwargs

-        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
+        if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                 inputs_tensor,
                 generation_config.pad_token_id,
@@ -209,16 +202,15 @@ class NewGenerationMixin(GenerationMixin):

         # 5. Prepare `input_ids` which will be used for auto-regressive generation
         if self.config.is_encoder_decoder:
-            input_ids = self._prepare_decoder_input_ids_for_generation(
-                batch_size,
-                decoder_start_token_id=generation_config.decoder_start_token_id,
-                bos_token_id=generation_config.bos_token_id,
+            input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
+                batch_size=batch_size,
+                model_input_name=model_input_name,
                 model_kwargs=model_kwargs,
+                decoder_start_token_id=generation_config.decoder_start_token_id,
                 device=inputs_tensor.device,
             )
         else:
-            # if decoder-only then inputs_tensor has to be `input_ids`
-            input_ids = inputs_tensor
+            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

         # 6. Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = input_ids.shape[-1]
@@ -577,7 +569,7 @@ class NewGenerationMixin(GenerationMixin):

         def typeerror():
             raise ValueError(
-                "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
+                "`force_words_ids` has to either be a `list[list[list[int]]]` or `list[list[int]]`"
                 f"of positive integers, but is {generation_config.force_words_ids}."
             )

@@ -649,7 +641,7 @@ class NewGenerationMixin(GenerationMixin):
         logits_warper: Optional[LogitsProcessorList] = None,
         max_length: Optional[int] = None,
         pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[Union[int, List[int]]] = None,
+        eos_token_id: Optional[Union[int, list[int]]] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_scores: Optional[bool] = None,
@@ -69,7 +69,7 @@ dependencies = [
     "gruut[de,es,fr]==2.2.3",
     # Tortoise
     "einops>=0.6.0",
-    "transformers>=4.33.0,<4.41.0",
+    "transformers>=4.41.1",
     # Bark
     "encodec>=0.1.1",
     # XTTS
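
The dependency pin above (from the project's dependency list, presumably pyproject.toml given the `dependencies = [` context) moves the package onto the transformers line that the reworked stream generator targets. For context, the path this commit repairs is XTTS streaming inference; a minimal usage sketch in the spirit of the project's XTTS streaming example, with placeholder paths and reference audio:

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

config = XttsConfig()
config.load_json("/path/to/xtts/config.json")  # placeholder path
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)

# Conditioning latents are computed once from a short reference clip.
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"])

# inference_stream yields audio chunk by chunk instead of one full waveform.
for i, chunk in enumerate(
    model.inference_stream(
        "It took me quite a long time to develop a voice.",
        "en",
        gpt_cond_latent,
        speaker_embedding,
    )
):
    print(f"received chunk {i}: {chunk.shape}")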