mirror of https://github.com/coqui-ai/TTS.git
Make it work on mps
This commit is contained in:
parent f5b81c9767
commit bd2f992e7e
@@ -182,14 +182,44 @@ class NewGenerationMixin(GenerationMixin):
         accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
         requires_attention_mask = "encoder_outputs" not in model_kwargs

-        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
-            pad_token_tensor = torch.tensor([generation_config.pad_token_id], device=inputs_tensor.device) if generation_config.pad_token_id is not None else None
-            eos_token_tensor = torch.tensor([generation_config.eos_token_id], device=inputs_tensor.device) if generation_config.eos_token_id is not None else None
-            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                inputs_tensor,
-                pad_token_tensor,
-                eos_token_tensor,
-            )
+        if (
+            model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask
+        ):
+            pad_token_tensor = (
+                torch.tensor([generation_config.pad_token_id], device=inputs_tensor.device)
+                if generation_config.pad_token_id is not None
+                else None
+            )
+            eos_token_tensor = (
+                torch.tensor([generation_config.eos_token_id], device=inputs_tensor.device)
+                if generation_config.eos_token_id is not None
+                else None
+            )
+
+            # hack to produce attention mask for mps devices since transformers bails but pytorch supports torch.isin on mps now
+            # for this to work, you must run with PYTORCH_ENABLE_MPS_FALLBACK=1 and call model.to(mps_device) on the XttsModel
+            if inputs_tensor.device.type == "mps":
+                default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
+
+                is_pad_token_in_inputs = (pad_token_tensor is not None) and (
+                    torch.isin(elements=inputs_tensor, test_elements=pad_token_tensor).any()
+                )
+                is_pad_token_not_equal_to_eos_token_id = (eos_token_tensor is None) or ~(
+                    torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
+                )
+                can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
+                attention_mask_from_padding = inputs_tensor.ne(pad_token_tensor).long()
+
+                model_kwargs["attention_mask"] = (
+                    attention_mask_from_padding * can_infer_attention_mask
+                    + default_attention_mask * ~can_infer_attention_mask
+                )
+            else:
+                model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+                    inputs_tensor,
+                    pad_token_tensor,
+                    eos_token_tensor,
+                )

         # decoder-only models should use left-padding for generation
         if not self.config.is_encoder_decoder:
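For context, the mps branch above reimplements the mask inference that transformers' `_prepare_attention_mask_for_generation` performs internally, now that `torch.isin` has an mps kernel. Below is a minimal standalone sketch of the same logic; the `pad_id`/`eos_id` values and the example inputs are hypothetical stand-ins for what the patch reads from `generation_config`:

```python
import torch

# hypothetical token ids for illustration; the patch takes them from generation_config
pad_id, eos_id = 0, 2

# fall back to cpu so the sketch also runs off Apple Silicon
device = "mps" if torch.backends.mps.is_available() else "cpu"

inputs = torch.tensor([[5, 7, 9, pad_id, pad_id]], device=device)
pad_token_tensor = torch.tensor([pad_id], device=device)
eos_token_tensor = torch.tensor([eos_id], device=device)

# a mask is only inferable when the pad token actually occurs in the inputs
# and is distinct from the eos token; otherwise every position stays unmasked
is_pad_in_inputs = torch.isin(inputs, pad_token_tensor).any()
pad_differs_from_eos = ~torch.isin(eos_token_tensor, pad_token_tensor).any()
can_infer = is_pad_in_inputs & pad_differs_from_eos

mask_from_padding = inputs.ne(pad_token_tensor).long()
default_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=device)

# blend: padding-derived mask when inferable, all-ones default otherwise
attention_mask = mask_from_padding * can_infer + default_mask * ~can_infer
print(attention_mask)  # tensor([[1, 1, 1, 0, 0]])
```

As the in-code comment notes, exercising this path end to end also requires starting the process with `PYTORCH_ENABLE_MPS_FALLBACK=1` (for ops that still lack mps kernels) and moving the XttsModel with `model.to(mps_device)`.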
@@ -3,8 +3,9 @@ numpy==1.22.0;python_version<="3.10"
 numpy>=1.24.3;python_version>"3.10"
 cython>=0.29.30
 scipy>=1.11.2
-torch>=2.1
-torchaudio
+torch==2.3.1
+torchaudio==2.3.1
+torchvision==0.18.1
 soundfile>=0.12.0
 librosa>=0.10.0
 scikit-learn>=1.3.0
@@ -48,7 +49,7 @@ bnnumerizer
 bnunicodenormalizer
 #deps for tortoise
 einops>=0.6.0
-transformers>=4.33.0
+transformers>=4.41.2
 #deps for bark
 encodec>=0.1.1
 # deps for XTTS
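Since the mps hack depends on the torch build, it may be worth sanity-checking that the pinned `torch==2.3.1` wheel actually ships the mps backend and runs `torch.isin` on it. A quick check, run with `PYTORCH_ENABLE_MPS_FALLBACK=1` exported as the patch comment advises:

```python
import torch

print(torch.__version__)                  # expected: 2.3.1
print(torch.backends.mps.is_built())      # True if this wheel was compiled with mps support
print(torch.backends.mps.is_available())  # True on Apple Silicon with a recent macOS

if torch.backends.mps.is_available():
    t = torch.tensor([1, 2, 3], device="mps")
    # torch.isin on mps is exactly what the generation patch above relies on
    print(torch.isin(t, torch.tensor([2], device="mps")))  # tensor([False, True, False], device='mps:0')
```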