From 2a281237d7f97608e88a32056f41a239b5dd6e77 Mon Sep 17 00:00:00 2001
From: Enno Hermann
Date: Sat, 15 Jun 2024 20:04:23 +0200
Subject: [PATCH] refactor(stream_generator): update code for transformers>=4.41.1

In line with https://github.com/huggingface/transformers/blob/eed9ed679878ada2f6d2eefccdbda368cabc88b1/src/transformers/generation/utils.py
---
 TTS/tts/layers/xtts/stream_generator.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/TTS/tts/layers/xtts/stream_generator.py b/TTS/tts/layers/xtts/stream_generator.py
index 6a152856..77432480 100644
--- a/TTS/tts/layers/xtts/stream_generator.py
+++ b/TTS/tts/layers/xtts/stream_generator.py
@@ -201,16 +201,15 @@ class NewGenerationMixin(GenerationMixin):
 
         # 5. Prepare `input_ids` which will be used for auto-regressive generation
         if self.config.is_encoder_decoder:
-            input_ids = self._prepare_decoder_input_ids_for_generation(
-                batch_size,
-                decoder_start_token_id=generation_config.decoder_start_token_id,
-                bos_token_id=generation_config.bos_token_id,
+            input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
+                batch_size=batch_size,
+                model_input_name=model_input_name,
                 model_kwargs=model_kwargs,
+                decoder_start_token_id=generation_config.decoder_start_token_id,
                 device=inputs_tensor.device,
             )
         else:
-            # if decoder-only then inputs_tensor has to be `input_ids`
-            input_ids = inputs_tensor
+            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
 
         # 6. Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = input_ids.shape[-1]