mirror of https://github.com/coqui-ai/TTS.git
fix(xtts): update streaming for transformers>=4.42.0 (#59)
* Fix Stream Generator on MacOS
* Make it work on mps
* Implement custom tensor.isin
* Fix for latest TF
* Comment out hack for now
* Remove unused code
* build: increase minimum transformers version
* style: fix

---------

Co-authored-by: Enno Hermann <Eginhard@users.noreply.github.com>
commit 20bbb411c2
parent 20583a496e
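The first three hunks below touch the sampling branches of the streaming generator's generate() override (class NewGenerationMixin); the last hunk raises the transformers floor in the project dependencies. The driver for the change: as of transformers 4.42.0, GenerationMixin._get_logits_warper takes the target device as a second positional argument, so the override now forwards inputs_tensor.device. A minimal sketch of the new call pattern, hedged: it relies on a private Hugging Face helper whose signature has shifted between releases, and the model name and sampling settings here are placeholders, not values from this repo.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Hedged sketch of the transformers>=4.42.0 call pattern used in the diff.
# "gpt2" and the sampling settings are placeholder choices, not repo defaults.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs_tensor = tokenizer("hello", return_tensors="pt").input_ids
generation_config = model.generation_config
generation_config.do_sample = True
generation_config.top_k = 50
generation_config.temperature = 0.8

# _get_logits_warper is a private helper; from 4.42.0 it expects the device
# as a second positional argument, which is what the diff below passes.
logits_warper = model._get_logits_warper(generation_config, inputs_tensor.device)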
@@ -376,7 +376,7 @@ class NewGenerationMixin(GenerationMixin):
 
         elif is_sample_gen_mode:
             # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config)
+            logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device)
 
             # 12. expand input_ids with `num_return_sequences` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -401,7 +401,7 @@ class NewGenerationMixin(GenerationMixin):
             )
         elif is_sample_gen_stream_mode:
             # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config)
+            logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device)
 
             # 12. expand input_ids with `num_return_sequences` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -463,7 +463,7 @@ class NewGenerationMixin(GenerationMixin):
 
         elif is_beam_sample_gen_mode:
             # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config)
+            logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device)
 
             if stopping_criteria.max_length is None:
                 raise ValueError("`max_length` needs to be a stopping_criteria for now.")
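For context, the object returned by _get_logits_warper is a LogitsProcessorList that reshapes the next-token distribution (temperature, top-k, top-p) before sampling. A rough sketch of how such a warper is typically consumed in a Hugging Face-style sampling step follows; this is the generic pattern, not this file's exact code, and the helper name is hypothetical.

import torch

def sample_next_token(outputs, input_ids, logits_processor, logits_warper):
    # Hedged sketch (hypothetical helper): one generic sampling step that
    # consumes the warper returned by _get_logits_warper.
    next_token_logits = outputs.logits[:, -1, :]
    next_token_scores = logits_processor(input_ids, next_token_logits)  # penalties, constraints
    next_token_scores = logits_warper(input_ids, next_token_scores)     # temperature / top-k / top-p
    probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
    return torch.cat([input_ids, next_tokens[:, None]], dim=-1)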
@@ -68,7 +68,7 @@ dependencies = [
     "gruut[de,es,fr]==2.2.3",
     # Tortoise
     "einops>=0.6.0",
-    "transformers>=4.41.1",
+    "transformers>=4.42.0",
     # Bark
     "encodec>=0.1.1",
     # XTTS
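This commit handles the signature change by simply requiring transformers>=4.42.0. If a downstream project needed to support both older and newer releases instead, one option would be to inspect the helper at runtime; a hedged sketch with a hypothetical helper name, not what this commit does:

import inspect

def get_logits_warper_compat(model, generation_config, device):
    # Hedged sketch, hypothetical helper: call the private _get_logits_warper
    # with or without the device argument depending on the installed
    # transformers version (>=4.42.0 requires it, older releases do not accept it).
    if "device" in inspect.signature(model._get_logits_warper).parameters:
        return model._get_logits_warper(generation_config, device)
    return model._get_logits_warper(generation_config)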