mirror of https://github.com/coqui-ai/TTS.git
refactor(stream_generator): update special tokens for transformers>=4.41.1
Fixes #31. The handling of special tokens in `transformers` was changed in https://github.com/huggingface/transformers/pull/30624 and https://github.com/huggingface/transformers/pull/30746. This updates the XTTS streaming code accordingly.
parent 81ac7abd58
commit 4b6da4e7ba
stream_generator.py

@@ -151,18 +151,7 @@ class NewGenerationMixin(GenerationMixin):
         # 2. Set generation parameters if not already defined
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-
-        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
-            if model_kwargs.get("attention_mask", None) is None:
-                logger.warning(
-                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
-                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
-                )
-            eos_token_id = generation_config.eos_token_id
-            if isinstance(eos_token_id, list):
-                eos_token_id = eos_token_id[0]
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
-            generation_config.pad_token_id = eos_token_id
+        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
 
         # 3. Define model inputs
         # inputs_tensor has to be defined
@@ -174,6 +163,9 @@ class NewGenerationMixin(GenerationMixin):
         )
         batch_size = inputs_tensor.shape[0]
 
+        device = inputs_tensor.device
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
+
         # 4. Define other model kwargs
         model_kwargs["output_attentions"] = generation_config.output_attentions
         model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
@@ -182,7 +174,7 @@ class NewGenerationMixin(GenerationMixin):
         accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
         requires_attention_mask = "encoder_outputs" not in model_kwargs
 
-        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
+        if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                 inputs_tensor,
                 generation_config.pad_token_id,
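Taken together, the three hunks above swap a hand-rolled pad-token fallback for the helper that newer `transformers` exposes on `GenerationMixin`. For context, a minimal sketch of the old and new behavior, assuming transformers>=4.41.1; the helper call is copied from the diff, while the standalone function is purely illustrative:

from transformers import GenerationConfig

def legacy_pad_token_fallback(generation_config: GenerationConfig) -> None:
    # What the deleted block did by hand: with no pad token configured,
    # fall back to the (first) EOS token for open-ended generation.
    if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
        eos_token_id = generation_config.eos_token_id
        if isinstance(eos_token_id, list):
            eos_token_id = eos_token_id[0]
        generation_config.pad_token_id = eos_token_id

# transformers>=4.41 centralizes this (plus device placement of the special
# tokens) in GenerationMixin._prepare_special_tokens, so the streaming
# generate() now only records whether the caller supplied an attention mask
# and delegates:
#
#   kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
#   ...
#   self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)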
pyproject.toml

@@ -69,7 +69,7 @@ dependencies = [
     "gruut[de,es,fr]==2.2.3",
     # Tortoise
     "einops>=0.6.0",
-    "transformers>=4.33.0,<4.41.0",
+    "transformers>=4.41.1",
     # Bark
     "encodec>=0.1.1",
     # XTTS
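The dependency pin flips from an upper bound to a lower bound because the streaming code now calls `_prepare_special_tokens`, which older releases lack. Where a hard failure is preferable to a late `AttributeError`, a hedged sketch of a runtime guard (the bound comes from the pin above; the message and placement are illustrative):

import transformers
from packaging.version import Version

# Fail fast when an incompatible transformers version is installed.
if Version(transformers.__version__) < Version("4.41.1"):
    raise RuntimeError(
        f"XTTS streaming requires transformers>=4.41.1, found {transformers.__version__}."
    )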