mirror of https://github.com/coqui-ai/TTS.git
refactor(stream_generator): update code for transformers>=4.41.1
In line with
eed9ed6798/src/transformers/generation/utils.py
This commit is contained in:
parent
4b6da4e7ba
commit
2a281237d7
|
@ -201,16 +201,15 @@ class NewGenerationMixin(GenerationMixin):
|
|||
|
||||
# 5. Prepare `input_ids` which will be used for auto-regressive generation
|
||||
if self.config.is_encoder_decoder:
|
||||
input_ids = self._prepare_decoder_input_ids_for_generation(
|
||||
batch_size,
|
||||
decoder_start_token_id=generation_config.decoder_start_token_id,
|
||||
bos_token_id=generation_config.bos_token_id,
|
||||
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
|
||||
batch_size=batch_size,
|
||||
model_input_name=model_input_name,
|
||||
model_kwargs=model_kwargs,
|
||||
decoder_start_token_id=generation_config.decoder_start_token_id,
|
||||
device=inputs_tensor.device,
|
||||
)
|
||||
else:
|
||||
# if decoder-only then inputs_tensor has to be `input_ids`
|
||||
input_ids = inputs_tensor
|
||||
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
|
||||
|
||||
# 6. Prepare `max_length` depending on other stopping criteria.
|
||||
input_ids_seq_length = input_ids.shape[-1]
|
||||
|
|
Loading…
Reference in New Issue