From 018d4ba1db2c6327c595d96f83848a5bd6a98ebd Mon Sep 17 00:00:00 2001
From: Johnny Street
Date: Sat, 5 Oct 2024 17:00:12 -0400
Subject: [PATCH] fix(xtts): support transformers>=4.43.0 in streaming inference

---
 TTS/tts/layers/xtts/stream_generator.py | 82 +++++++++++++++++++------
 pyproject.toml                           |  2 +-
 2 files changed, 64 insertions(+), 20 deletions(-)

diff --git a/TTS/tts/layers/xtts/stream_generator.py b/TTS/tts/layers/xtts/stream_generator.py
index efc92a04..44cf940c 100644
--- a/TTS/tts/layers/xtts/stream_generator.py
+++ b/TTS/tts/layers/xtts/stream_generator.py
@@ -20,8 +20,10 @@ from transformers import (
     PhrasalConstraint,
     PreTrainedModel,
     StoppingCriteriaList,
+    TemperatureLogitsWarper,
+    TopKLogitsWarper,
+    TopPLogitsWarper,
 )
-from transformers.generation.stopping_criteria import validate_stopping_criteria
 from transformers.generation.utils import GenerateOutput, SampleOutput, logger


@@ -152,7 +154,18 @@ class NewGenerationMixin(GenerationMixin):
         # 2. Set generation parameters if not already defined
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+
+        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
+            if model_kwargs.get("attention_mask", None) is None:
+                logger.warning(
+                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
+                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
+                )
+            eos_token_id = generation_config.eos_token_id
+            if isinstance(eos_token_id, list):
+                eos_token_id = eos_token_id[0]
+            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
+            generation_config.pad_token_id = eos_token_id

         # 3. Define model inputs
         # inputs_tensor has to be defined
@@ -164,22 +177,38 @@ class NewGenerationMixin(GenerationMixin):
         )
         batch_size = inputs_tensor.shape[0]
-        device = inputs_tensor.device
-        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
-
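+        # NOTE: pad/eos handling is done explicitly in this method (sections 2 and 4 below)
+        # instead of via self._prepare_special_tokens()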
         # 4. Define other model kwargs
         model_kwargs["output_attentions"] = generation_config.output_attentions
         model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
         model_kwargs["use_cache"] = generation_config.use_cache
+        model_kwargs["cache_position"] = torch.Tensor([0]).to(inputs_tensor.device)

         accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
         requires_attention_mask = "encoder_outputs" not in model_kwargs

-        if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
+        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
+            setattr(
+                generation_config,
+                "_pad_token_tensor",
+                torch.full(
+                    (inputs_tensor.shape[0], inputs_tensor.shape[1]),
+                    generation_config.pad_token_id,
+                    device=inputs_tensor.device,
+                ),
+            )
+            setattr(
+                generation_config,
+                "_eos_token_tensor",
+                torch.full(
+                    (inputs_tensor.shape[0], inputs_tensor.shape[1]),
+                    generation_config.eos_token_id,
+                    device=inputs_tensor.device,
+                ),
+            )
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                 inputs_tensor,
-                generation_config.pad_token_id,
-                generation_config.eos_token_id,
+                generation_config._pad_token_tensor,
+                generation_config._eos_token_tensor,
             )

         # decoder-only models should use left-padding for generation
@@ -202,15 +231,16 @@ class NewGenerationMixin(GenerationMixin):
         # 5. Prepare `input_ids` which will be used for auto-regressive generation
         if self.config.is_encoder_decoder:
-            input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
-                batch_size=batch_size,
-                model_input_name=model_input_name,
-                model_kwargs=model_kwargs,
+            input_ids = self._prepare_decoder_input_ids_for_generation(
+                batch_size,
                 decoder_start_token_id=generation_config.decoder_start_token_id,
+                bos_token_id=generation_config.bos_token_id,
+                model_kwargs=model_kwargs,
                 device=inputs_tensor.device,
             )
         else:
-            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
+            # if decoder-only then inputs_tensor has to be `input_ids`
+            input_ids = inputs_tensor

         # 6. Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = input_ids.shape[-1]
@@ -376,7 +406,7 @@ class NewGenerationMixin(GenerationMixin):

         elif is_sample_gen_mode:
             # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device)
+            logits_warper = _get_logits_warper(generation_config)

             # 12. expand input_ids with `num_return_sequences` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -401,7 +431,7 @@ class NewGenerationMixin(GenerationMixin):
             )
         elif is_sample_gen_stream_mode:
             # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device)
+            logits_warper = _get_logits_warper(generation_config)

             # 12. expand input_ids with `num_return_sequences` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -463,7 +493,7 @@ class NewGenerationMixin(GenerationMixin):

         elif is_beam_sample_gen_mode:
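+            # (uses the module-level _get_logits_warper helper defined at the end of this file)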
             # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device)
+            logits_warper = _get_logits_warper(generation_config)

             if stopping_criteria.max_length is None:
                 raise ValueError("`max_length` needs to be a stopping_criteria for now.")
@@ -877,10 +907,10 @@ def init_stream_support():

 if __name__ == "__main__":
-    from transformers import AutoModelForCausalLM, AutoTokenizer
-
-    init_stream_support()
+    from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
+    PreTrainedModel.generate = NewGenerationMixin.generate
+    PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
     model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.float16)
     tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

@@ -920,3 +950,17 @@ if __name__ == "__main__":
         chunk = tokenizer.decode(x, skip_special_tokens=True)
         stream_result += chunk
     print(stream_result)
+
+
+def _get_logits_warper(generation_config: GenerationConfig) -> LogitsProcessorList:
+
+    warpers = LogitsProcessorList()
+
+    if generation_config.temperature is not None and generation_config.temperature != 1.0:
+        warpers.append(TemperatureLogitsWarper(generation_config.temperature))
+    if generation_config.top_k is not None and generation_config.top_k != 0:
+        warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1))
+    if generation_config.top_p is not None and generation_config.top_p < 1.0:
+        warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=1))
+
+    return warpers
diff --git a/pyproject.toml b/pyproject.toml
index 4d01e91b..9b2b137d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,7 +68,7 @@ dependencies = [
     "gruut[de,es,fr]>=2.4.0",
     # Tortoise
     "einops>=0.6.0",
-    "transformers>=4.42.0,<4.43.0",
+    "transformers>=4.43.0",
     # Bark
     "encodec>=0.1.1",
     # XTTS
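
--
Note (illustrative, not part of the patch): the module-level `_get_logits_warper`
added above rebuilds the temperature/top-k/top-p warper chain locally, so the
streaming generator no longer calls the private `GenerationMixin._get_logits_warper`.
A minimal sanity check of the helper, assuming transformers>=4.43.0 is installed
(the sampling values below are arbitrary):

    from transformers import GenerationConfig
    from TTS.tts.layers.xtts.stream_generator import _get_logits_warper

    # a config with non-default sampling parameters so all three warpers fire
    cfg = GenerationConfig(do_sample=True, temperature=0.75, top_k=50, top_p=0.85)
    warpers = _get_logits_warper(cfg)
    print([type(w).__name__ for w in warpers])
    # expected: ['TemperatureLogitsWarper', 'TopKLogitsWarper', 'TopPLogitsWarper']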