mirror of https://github.com/coqui-ai/TTS.git
fix(xtts): support transformers>=4.43.0 in streaming inference
parent 073f8de652
commit 018d4ba1db
@@ -20,8 +20,10 @@ from transformers import (
     PhrasalConstraint,
     PreTrainedModel,
     StoppingCriteriaList,
+    TemperatureLogitsWarper,
+    TopKLogitsWarper,
+    TopPLogitsWarper,
 )
-from transformers.generation.stopping_criteria import validate_stopping_criteria
 from transformers.generation.utils import GenerateOutput, SampleOutput, logger

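Note: the warper classes are now imported directly and the `validate_stopping_criteria` import is dropped, presumably because those `transformers` internals changed around 4.43; the sampling warpers are instead assembled by the module-level `_get_logits_warper` added at the end of this diff. A minimal sketch of how such a warper list filters logits (made-up scores, not from this repository):

import torch
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper

# made-up logits for a batch of 1 over a 5-token vocabulary
scores = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])
input_ids = torch.tensor([[0]])

warpers = LogitsProcessorList(
    [TemperatureLogitsWarper(0.7), TopKLogitsWarper(top_k=3), TopPLogitsWarper(top_p=0.9)]
)
warped = warpers(input_ids, scores)  # entries outside top-k / top-p are pushed to -inf
print(warped)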
@@ -152,7 +154,18 @@ class NewGenerationMixin(GenerationMixin):
         # 2. Set generation parameters if not already defined
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+
+        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
+            if model_kwargs.get("attention_mask", None) is None:
+                logger.warning(
+                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
+                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
+                )
+            eos_token_id = generation_config.eos_token_id
+            if isinstance(eos_token_id, list):
+                eos_token_id = eos_token_id[0]
+            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
+            generation_config.pad_token_id = eos_token_id

         # 3. Define model inputs
         # inputs_tensor has to be defined
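Note: the hunk above restores the classic `generate()` fallback: when `pad_token_id` is unset, the (first) `eos_token_id` is reused for padding and a warning is logged. A minimal sketch of that fallback in isolation (ids are made up):

from transformers import GenerationConfig

# made-up ids; eos_token_id may be a list, in which case the first entry is used
cfg = GenerationConfig(eos_token_id=[2, 3], pad_token_id=None)

if cfg.pad_token_id is None and cfg.eos_token_id is not None:
    eos = cfg.eos_token_id[0] if isinstance(cfg.eos_token_id, list) else cfg.eos_token_id
    cfg.pad_token_id = eos  # padding now falls back to token id 2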
@@ -164,22 +177,38 @@ class NewGenerationMixin(GenerationMixin):
         )
         batch_size = inputs_tensor.shape[0]

-        device = inputs_tensor.device
-        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
-
         # 4. Define other model kwargs
         model_kwargs["output_attentions"] = generation_config.output_attentions
         model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
         model_kwargs["use_cache"] = generation_config.use_cache
+        model_kwargs["cache_position"] = torch.Tensor([0]).to(inputs_tensor.device)

         accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
         requires_attention_mask = "encoder_outputs" not in model_kwargs

-        if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
+        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
+            setattr(
+                generation_config,
+                "_pad_token_tensor",
+                torch.full(
+                    (inputs_tensor.shape[0], inputs_tensor.shape[1]),
+                    generation_config.pad_token_id,
+                    device=inputs_tensor.device,
+                ),
+            )
+            setattr(
+                generation_config,
+                "_eos_token_tensor",
+                torch.full(
+                    (inputs_tensor.shape[0], inputs_tensor.shape[1]),
+                    generation_config.eos_token_id,
+                    device=inputs_tensor.device,
+                ),
+            )
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                 inputs_tensor,
-                generation_config.pad_token_id,
-                generation_config.eos_token_id,
+                generation_config._pad_token_tensor,
+                generation_config._eos_token_tensor,
             )

         # decoder-only models should use left-padding for generation
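Note: rather than calling `self._prepare_special_tokens`, the patched code attaches `_pad_token_tensor` / `_eos_token_tensor` to the generation config before calling `_prepare_attention_mask_for_generation`, matching the tensor-valued attributes newer `transformers` releases pass around. Conceptually the resulting mask just marks non-pad positions; a rough sketch with made-up ids (not the library implementation):

import torch

# made-up right-padded batch; 0 plays the role of the pad token id
input_ids = torch.tensor([[5, 7, 8, 0, 0],
                          [9, 4, 0, 0, 0]])
pad_token_id = 0
attention_mask = input_ids.ne(pad_token_id).long()
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 0, 0, 0]])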
@@ -202,15 +231,16 @@ class NewGenerationMixin(GenerationMixin):

         # 5. Prepare `input_ids` which will be used for auto-regressive generation
         if self.config.is_encoder_decoder:
-            input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
-                batch_size=batch_size,
-                model_input_name=model_input_name,
-                model_kwargs=model_kwargs,
+            input_ids = self._prepare_decoder_input_ids_for_generation(
+                batch_size,
                 decoder_start_token_id=generation_config.decoder_start_token_id,
+                bos_token_id=generation_config.bos_token_id,
+                model_kwargs=model_kwargs,
                 device=inputs_tensor.device,
             )
         else:
-            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
+            # if decoder-only then inputs_tensor has to be `input_ids`
+            input_ids = inputs_tensor

         # 6. Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = input_ids.shape[-1]
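Note: the encoder-decoder branch switches to the bos-style call and the decoder-only branch now simply reuses `inputs_tensor` as `input_ids`. When no explicit decoder inputs are supplied, decoder input preparation effectively reduces to a column of start tokens; a sketch with made-up values:

import torch

# made-up sizes/ids: what decoder input preparation boils down to when nothing is passed in
batch_size = 2
decoder_start_token_id = 1
decoder_input_ids = torch.full((batch_size, 1), decoder_start_token_id, dtype=torch.long)
# tensor([[1],
#         [1]])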
@@ -376,7 +406,7 @@ class NewGenerationMixin(GenerationMixin):

         elif is_sample_gen_mode:
             # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device)
+            logits_warper = _get_logits_warper(generation_config)

             # 12. expand input_ids with `num_return_sequences` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -401,7 +431,7 @@ class NewGenerationMixin(GenerationMixin):
             )
         elif is_sample_gen_stream_mode:
             # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device)
+            logits_warper = _get_logits_warper(generation_config)

             # 12. expand input_ids with `num_return_sequences` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -463,7 +493,7 @@ class NewGenerationMixin(GenerationMixin):

         elif is_beam_sample_gen_mode:
             # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device)
+            logits_warper = _get_logits_warper(generation_config)

             if stopping_criteria.max_length is None:
                 raise ValueError("`max_length` needs to be a stopping_criteria for now.")
@@ -877,10 +907,10 @@ def init_stream_support():


 if __name__ == "__main__":
-    from transformers import AutoModelForCausalLM, AutoTokenizer
+    from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel

-    init_stream_support()
-
+    PreTrainedModel.generate = NewGenerationMixin.generate
+    PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
     model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.float16)

     tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
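Note: the demo block no longer calls `init_stream_support()` and instead assigns the mixin's methods onto `PreTrainedModel` directly, so every model class picks up the patched `generate` and the streaming sampler. A toy illustration of why assigning a function onto a class works this way (toy classes, unrelated to `transformers`):

class StreamingMixin:
    def describe(self):
        return f"streaming enabled on {type(self).__name__}"

class SomeModel:
    pass

# assigning the plain function to the class turns it into an ordinary method of SomeModel
SomeModel.describe = StreamingMixin.describe
print(SomeModel().describe())  # -> "streaming enabled on SomeModel"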
@@ -920,3 +950,17 @@ if __name__ == "__main__":
         chunk = tokenizer.decode(x, skip_special_tokens=True)
         stream_result += chunk
         print(stream_result)
+
+
+def _get_logits_warper(generation_config: GenerationConfig) -> LogitsProcessorList:
+
+    warpers = LogitsProcessorList()
+
+    if generation_config.temperature is not None and generation_config.temperature != 1.0:
+        warpers.append(TemperatureLogitsWarper(generation_config.temperature))
+    if generation_config.top_k is not None and generation_config.top_k != 0:
+        warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1))
+    if generation_config.top_p is not None and generation_config.top_p < 1.0:
+        warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=1))
+
+    return warpers
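Note: the module-level `_get_logits_warper` defined above is what the three call sites earlier in this diff now use, standing in for the private `GenerationMixin` helper the streaming code previously relied on (which changed in newer `transformers` releases). A minimal usage sketch (sampling settings are made up):

from transformers import GenerationConfig

cfg = GenerationConfig(do_sample=True, temperature=0.75, top_k=50, top_p=0.85)
warpers = _get_logits_warper(cfg)
# -> LogitsProcessorList with temperature, top-k and top-p warpers, in that order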
@@ -68,7 +68,7 @@ dependencies = [
     "gruut[de,es,fr]>=2.4.0",
     # Tortoise
     "einops>=0.6.0",
-    "transformers>=4.42.0,<4.43.0",
+    "transformers>=4.43.0",
     # Bark
     "encodec>=0.1.1",
     # XTTS
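Note: with the streaming path updated, the dependency pin moves from the `>=4.42.0,<4.43.0` window to a plain `>=4.43.0` lower bound. A minimal runtime guard in the same spirit (this check is not part of the commit):

from packaging.version import Version
import transformers

if Version(transformers.__version__) < Version("4.43.0"):
    raise RuntimeError("streaming inference here expects transformers>=4.43.0")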