mirror of https://github.com/coqui-ai/TTS.git
Comment out hack for now
This commit is contained in:
parent 61ec4322d4
commit e8663dd3f8
@@ -195,36 +195,41 @@ class NewGenerationMixin(GenerationMixin):
    generation_config.pad_token_id,
    generation_config.eos_token_id,
)
eos_token_tensor = (
    torch.tensor([generation_config.eos_token_id], device=inputs_tensor.device)
    if generation_config.eos_token_id is not None
    else None
)
# pad_token_tensor = (
#     torch.tensor([generation_config.pad_token_id], device=inputs_tensor.device)
#     if generation_config.pad_token_id is not None
#     else None
# )
# eos_token_tensor = (
#     torch.tensor([generation_config.eos_token_id], device=inputs_tensor.device)
#     if generation_config.eos_token_id is not None
#     else None
# )

# hack to produce attention mask for mps devices since transformers bails but pytorch supports torch.isin on mps now
# for this to work, you must run with PYTORCH_ENABLE_MPS_FALLBACK=1 and call model.to(mps_device) on the XttsModel
if inputs_tensor.device.type == "mps":
    default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
# # hack to produce attention mask for mps devices since transformers bails but pytorch supports torch.isin on mps now
# # for this to work, you must run with PYTORCH_ENABLE_MPS_FALLBACK=1 and call model.to(mps_device) on the XttsModel
# if inputs_tensor.device.type == "mps":
#     default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)

    is_pad_token_in_inputs = (pad_token_tensor is not None) and (
        custom_isin(elements=inputs_tensor, test_elements=pad_token_tensor).any()
    )
    is_pad_token_not_equal_to_eos_token_id = (eos_token_tensor is None) or ~(
        custom_isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
    )
    can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
    attention_mask_from_padding = inputs_tensor.ne(pad_token_tensor).long()
# is_pad_token_in_inputs = (pad_token_tensor is not None) and (
#     custom_isin(elements=inputs_tensor, test_elements=pad_token_tensor).any()
# )
# is_pad_token_not_equal_to_eos_token_id = (eos_token_tensor is None) or ~(
#     custom_isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
# )
# can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
# attention_mask_from_padding = inputs_tensor.ne(pad_token_tensor).long()

    model_kwargs["attention_mask"] = (
        attention_mask_from_padding * can_infer_attention_mask
        + default_attention_mask * ~can_infer_attention_mask
    )
else:
    model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
        inputs_tensor,
        pad_token_tensor,
        eos_token_tensor,
    )
# model_kwargs["attention_mask"] = (
#     attention_mask_from_padding * can_infer_attention_mask
#     + default_attention_mask * ~can_infer_attention_mask
# )
# else:
#     model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
#         inputs_tensor,
#         pad_token_tensor,
#         eos_token_tensor,
#     )

# decoder-only models should use left-padding for generation
if not self.config.is_encoder_decoder:
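For context on what the commented-out hack does: it inlines the attention-mask inference that transformers performs in _prepare_attention_mask_for_generation, substituting a custom_isin helper for torch.isin so the check can run on mps devices. The sketch below is a minimal, self-contained approximation of that logic; build_attention_mask and this particular custom_isin implementation are illustrative stand-ins written for this note, not code from the repository or from transformers.

import torch


def custom_isin(elements: torch.Tensor, test_elements: torch.Tensor) -> torch.Tensor:
    # Illustrative stand-in for torch.isin: broadcast-compare every element
    # against the flattened test values and reduce over the trailing dim.
    return (elements.unsqueeze(-1) == test_elements.flatten()).any(dim=-1)


def build_attention_mask(inputs, pad_token_id=None, eos_token_id=None):
    # Fallback: attend to every position when padding cannot be inferred.
    default_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)

    pad_token = None if pad_token_id is None else torch.tensor([pad_token_id], device=inputs.device)
    eos_token = None if eos_token_id is None else torch.tensor([eos_token_id], device=inputs.device)

    # Padding is only trustworthy if the pad token actually occurs in the inputs
    # and differs from the eos token (otherwise a final eos would be masked out).
    pad_in_inputs = pad_token is not None and bool(custom_isin(inputs, pad_token).any())
    pad_differs_from_eos = eos_token is None or pad_token is None or not bool(
        custom_isin(eos_token, pad_token).any()
    )

    if pad_in_inputs and pad_differs_from_eos:
        return inputs.ne(pad_token).long()
    return default_mask


# Two right-padded sequences, pad_token_id=0, eos_token_id=2:
batch = torch.tensor([[5, 6, 2, 0, 0], [7, 2, 0, 0, 0]])
print(build_attention_mask(batch, pad_token_id=0, eos_token_id=2))
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 0, 0, 0]])

As the comments in the diff note, exercising the mps path on the real model also requires running with PYTORCH_ENABLE_MPS_FALLBACK=1 and calling model.to(mps_device) on the XttsModel.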