mirror of https://github.com/coqui-ai/TTS.git
refactor(stream_generator): update special tokens for transformers>=4.41.1
Fixes #31. The handling of special tokens in `transformers` was changed in https://github.com/huggingface/transformers/pull/30624 and https://github.com/huggingface/transformers/pull/30746. This updates the XTTS streaming code accordingly.
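
The helper the patched code relies on only ships with newer transformers releases. A quick way to confirm the installed build matches is a hedged probe like the one below; this is a sketch, not part of the commit, and it only assumes that `GenerationMixin._prepare_special_tokens` (used in the diff below) is available from transformers 4.41 onwards.

# Hedged environment probe (not part of this commit): check that the installed
# transformers build exposes the special-token helper the patched code calls.
import inspect

from transformers import GenerationMixin

if hasattr(GenerationMixin, "_prepare_special_tokens"):
    # Present on transformers >= 4.41, so the streaming generator can delegate to it.
    print(inspect.signature(GenerationMixin._prepare_special_tokens))
else:
    print("transformers predates 4.41; upgrade to >= 4.41.1 before applying this patch.")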
parent 81ac7abd58
commit 4b6da4e7ba
stream_generator.py
@@ -151,18 +151,7 @@ class NewGenerationMixin(GenerationMixin):
         # 2. Set generation parameters if not already defined
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

-        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
-            if model_kwargs.get("attention_mask", None) is None:
-                logger.warning(
-                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
-                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
-                )
-            eos_token_id = generation_config.eos_token_id
-            if isinstance(eos_token_id, list):
-                eos_token_id = eos_token_id[0]
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
-            generation_config.pad_token_id = eos_token_id
+        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

         # 3. Define model inputs
         # inputs_tensor has to be defined
@@ -174,6 +163,9 @@ class NewGenerationMixin(GenerationMixin):
         )
         batch_size = inputs_tensor.shape[0]

+        device = inputs_tensor.device
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
+
         # 4. Define other model kwargs
         model_kwargs["output_attentions"] = generation_config.output_attentions
         model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
@@ -182,7 +174,7 @@ class NewGenerationMixin(GenerationMixin):
         accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
         requires_attention_mask = "encoder_outputs" not in model_kwargs

-        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
+        if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                 inputs_tensor,
                 generation_config.pad_token_id,
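
For context, the hunks above touch `NewGenerationMixin.generate`, which XTTS streaming inference drives. A minimal usage sketch follows; the checkpoint and audio paths are placeholders, and the calls reflect the public XTTS v2 streaming API rather than code from this commit.

# Hedged usage sketch (placeholder paths, not part of this commit): streaming
# synthesis exercises the patched NewGenerationMixin.generate code path.
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

config = XttsConfig()
config.load_json("/path/to/xtts/config.json")  # placeholder
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)  # placeholder

# Speaker conditioning from a short reference clip (placeholder file name).
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"])

# inference_stream yields audio chunks as soon as they are generated.
for i, chunk in enumerate(model.inference_stream("Hello world.", "en", gpt_cond_latent, speaker_embedding)):
    print(f"received chunk {i} with shape {tuple(chunk.shape)}")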
pyproject.toml
@@ -69,7 +69,7 @@ dependencies = [
     "gruut[de,es,fr]==2.2.3",
     # Tortoise
     "einops>=0.6.0",
-    "transformers>=4.33.0,<4.41.0",
+    "transformers>=4.41.1",
     # Bark
     "encodec>=0.1.1",
     # XTTS
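
Because the lower bound now matches the release that introduced the new special-token handling, a startup guard can surface an outdated environment early. This is a hypothetical helper, not part of the commit:

# Hypothetical startup guard (not part of this commit): fail fast when the
# installed transformers release predates the new lower bound.
from importlib.metadata import version

from packaging.version import Version

installed = Version(version("transformers"))
if installed < Version("4.41.1"):
    raise RuntimeError(f"transformers {installed} found, but >= 4.41.1 is required.")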