mirror of https://github.com/coqui-ai/TTS.git
Implement custom tensor.isin
This commit is contained in:
parent bd2f992e7e
commit 6696abfa52
@@ -23,6 +23,20 @@ from transformers import (
 )
 from transformers.generation.utils import GenerateOutput, SampleOutput, logger
 
 
+def custom_isin(elements, test_elements):
+    # Flatten the tensors
+    elements_flat = elements.view(-1)
+    test_elements_flat = test_elements.view(-1)
+
+    # Create a mask tensor
+    mask = torch.zeros_like(elements_flat, dtype=torch.bool)
+
+    # Compare each element
+    for test_element in test_elements_flat:
+        mask |= (elements_flat == test_element)
+
+    # Reshape the mask to the original elements shape
+    return mask.view(elements.shape)
+
+
 def setup_seed(seed):
     if seed == -1:
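For reference, a minimal standalone check (not part of the commit) of what the new helper does: it reproduces torch.isin's membership test for the integer token-id tensors used here, so on PyTorch builds where torch.isin is available the two should agree. The tensors below are illustrative values, not taken from the repository.

import torch

def custom_isin(elements, test_elements):
    # Copied from the hunk above: mark each position of `elements` whose
    # value appears anywhere in `test_elements`.
    elements_flat = elements.view(-1)
    mask = torch.zeros_like(elements_flat, dtype=torch.bool)
    for test_element in test_elements.view(-1):
        mask |= (elements_flat == test_element)
    return mask.view(elements.shape)

inputs = torch.tensor([[5, 0, 7], [0, 3, 5]])
pad = torch.tensor([0])
print(custom_isin(inputs, pad))
# tensor([[False,  True, False],
#         [ True, False, False]])
assert torch.equal(custom_isin(inputs, pad), torch.isin(inputs, pad))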
@@ -202,10 +216,10 @@ class NewGenerationMixin(GenerationMixin):
         default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
 
         is_pad_token_in_inputs = (pad_token_tensor is not None) and (
-            torch.isin(elements=inputs_tensor, test_elements=pad_token_tensor).any()
+            custom_isin(elements=inputs_tensor, test_elements=pad_token_tensor).any()
         )
         is_pad_token_not_equal_to_eos_token_id = (eos_token_tensor is None) or ~(
-            torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
+            custom_isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
         )
         can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
         attention_mask_from_padding = inputs_tensor.ne(pad_token_tensor).long()
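The second hunk only swaps torch.isin for the new helper inside the attention-mask inference; the surrounding logic is unchanged. Below is a toy walk-through of that logic with made-up token ids (pad = 0, eos = 2); custom_isin is repeated from the first hunk so the sketch runs on its own.

import torch

# custom_isin as added in the first hunk (repeated so this sketch is standalone).
def custom_isin(elements, test_elements):
    elements_flat = elements.view(-1)
    mask = torch.zeros_like(elements_flat, dtype=torch.bool)
    for test_element in test_elements.view(-1):
        mask |= (elements_flat == test_element)
    return mask.view(elements.shape)

inputs_tensor = torch.tensor([[101, 42, 0, 0]])  # illustrative ids, 0 used as padding
pad_token_tensor = torch.tensor([0])
eos_token_tensor = torch.tensor([2])

# Same decision chain as in the hunk above: infer an attention mask only when the
# pad token actually occurs in the input and differs from the eos token.
is_pad_token_in_inputs = (pad_token_tensor is not None) and (
    custom_isin(elements=inputs_tensor, test_elements=pad_token_tensor).any()
)
is_pad_token_not_equal_to_eos_token_id = (eos_token_tensor is None) or ~(
    custom_isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
)
can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
attention_mask_from_padding = inputs_tensor.ne(pad_token_tensor).long()

print(can_infer_attention_mask)      # tensor(True)
print(attention_mask_from_padding)   # tensor([[1, 1, 0, 0]])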