From e8663dd3f871f09be4a56c68bda3562f43b2d97c Mon Sep 17 00:00:00 2001
From: Daniel Walmsley
Date: Mon, 8 Jul 2024 15:55:33 -0700
Subject: [PATCH] Comment out hack for now

---
 TTS/tts/layers/xtts/stream_generator.py | 59 ++++++++++++++-----------
 1 file changed, 32 insertions(+), 27 deletions(-)

diff --git a/TTS/tts/layers/xtts/stream_generator.py b/TTS/tts/layers/xtts/stream_generator.py
index 91905d3d..de3ae760 100644
--- a/TTS/tts/layers/xtts/stream_generator.py
+++ b/TTS/tts/layers/xtts/stream_generator.py
@@ -195,36 +195,41 @@ class NewGenerationMixin(GenerationMixin):
             generation_config.pad_token_id,
             generation_config.eos_token_id,
         )
-        eos_token_tensor = (
-            torch.tensor([generation_config.eos_token_id], device=inputs_tensor.device)
-            if generation_config.eos_token_id is not None
-            else None
-        )
+        # pad_token_tensor = (
+        #     torch.tensor([generation_config.pad_token_id], device=inputs_tensor.device)
+        #     if generation_config.pad_token_id is not None
+        #     else None
+        # )
+        # eos_token_tensor = (
+        #     torch.tensor([generation_config.eos_token_id], device=inputs_tensor.device)
+        #     if generation_config.eos_token_id is not None
+        #     else None
+        # )
 
-        # hack to produce attention mask for mps devices since transformers bails but pytorch supports torch.isin on mps now
-        # for this to work, you must run with PYTORCH_ENABLE_MPS_FALLBACK=1 and call model.to(mps_device) on the XttsModel
-        if inputs_tensor.device.type == "mps":
-            default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
+        # # hack to produce attention mask for mps devices since transformers bails but pytorch supports torch.isin on mps now
+        # # for this to work, you must run with PYTORCH_ENABLE_MPS_FALLBACK=1 and call model.to(mps_device) on the XttsModel
+        # if inputs_tensor.device.type == "mps":
+        #     default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
 
-            is_pad_token_in_inputs = (pad_token_tensor is not None) and (
-                custom_isin(elements=inputs_tensor, test_elements=pad_token_tensor).any()
-            )
-            is_pad_token_not_equal_to_eos_token_id = (eos_token_tensor is None) or ~(
-                custom_isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
-            )
-            can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
-            attention_mask_from_padding = inputs_tensor.ne(pad_token_tensor).long()
+        #     is_pad_token_in_inputs = (pad_token_tensor is not None) and (
+        #         custom_isin(elements=inputs_tensor, test_elements=pad_token_tensor).any()
+        #     )
+        #     is_pad_token_not_equal_to_eos_token_id = (eos_token_tensor is None) or ~(
+        #         custom_isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
+        #     )
+        #     can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
+        #     attention_mask_from_padding = inputs_tensor.ne(pad_token_tensor).long()
 
-            model_kwargs["attention_mask"] = (
-                attention_mask_from_padding * can_infer_attention_mask
-                + default_attention_mask * ~can_infer_attention_mask
-            )
-        else:
-            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                inputs_tensor,
-                pad_token_tensor,
-                eos_token_tensor,
-            )
+        #     model_kwargs["attention_mask"] = (
+        #         attention_mask_from_padding * can_infer_attention_mask
+        #         + default_attention_mask * ~can_infer_attention_mask
+        #     )
+        # else:
+        #     model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+        #         inputs_tensor,
+        #         pad_token_tensor,
+        #         eos_token_tensor,
+        #     )
 
         # decoder-only models should use left-padding for generation
         if not self.config.is_encoder_decoder:
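Note (not part of the patch): the block being commented out infers an attention mask from padding tokens on MPS devices. A minimal standalone sketch of that logic is below, assuming torch.isin is available on the target device; it uses torch.isin directly in place of the repo's custom_isin helper, and infer_attention_mask is a hypothetical name used only for illustration.

    import torch


    def infer_attention_mask(inputs, pad_token_id=None, eos_token_id=None):
        # Hypothetical helper, not part of the patch: mirrors the commented-out
        # logic using torch.isin in place of custom_isin.
        default_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
        if pad_token_id is None:
            return default_mask
        pad = torch.tensor([pad_token_id], device=inputs.device)
        # Only trust padding as a mask signal if a pad token actually appears
        # in the inputs and is distinct from the EOS token.
        pad_in_inputs = torch.isin(inputs, pad).any()
        if eos_token_id is None:
            pad_differs_from_eos = torch.tensor(True, device=inputs.device)
        else:
            eos = torch.tensor([eos_token_id], device=inputs.device)
            pad_differs_from_eos = ~torch.isin(eos, pad).any()
        can_infer = (pad_in_inputs & pad_differs_from_eos).long()
        mask_from_padding = inputs.ne(pad_token_id).long()
        # Use the padding-derived mask when it is trustworthy, else all ones.
        return mask_from_padding * can_infer + default_mask * (1 - can_infer)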