mirror of https://github.com/coqui-ai/TTS.git
fix(gpt): set attention mask and address other warnings
parent b66c782931
commit 964b813235
@@ -1,14 +1,22 @@
 # AGPL: a notification must be added stating that changes have been made to that file.
 import functools
+from typing import Optional
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import transformers
+from packaging.version import Version
 from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 
 from TTS.tts.layers.tortoise.arch_utils import AttentionBlock, TypicalLogitsWarper
 
+if Version(transformers.__version__) >= Version("4.45"):
+    isin = transformers.pytorch_utils.isin_mps_friendly
+else:
+    isin = torch.isin
+
 
 def null_position_embeddings(range, dim):
     return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
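Note: transformers >= 4.45 ships `isin_mps_friendly`, a variant of `torch.isin` that also works on Apple's MPS backend; the gate above picks it up when available and falls back to plain `torch.isin` otherwise. A minimal standalone sketch of that shim (the example tensors are invented):

import torch
import transformers
from packaging.version import Version

# Same version gate as in the diff: prefer the MPS-friendly isin when transformers provides it.
if Version(transformers.__version__) >= Version("4.45"):
    isin = transformers.pytorch_utils.isin_mps_friendly
else:
    isin = torch.isin

tokens = torch.tensor([[5, 9, 1023, 9]])   # invented token ids
stop_token = torch.tensor(1023)            # invented stop token
print(isin(elements=tokens, test_elements=stop_token))  # tensor([[False, False,  True, False]])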
@@ -596,6 +604,8 @@ class UnifiedVoice(nn.Module):
         max_length = (
             trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
         )
+        stop_token_tensor = torch.tensor(self.stop_mel_token, device=inputs.device, dtype=torch.long)
+        attention_mask = _prepare_attention_mask_for_generation(inputs, stop_token_tensor, stop_token_tensor)
         gen = self.inference_model.generate(
             inputs,
             bos_token_id=self.start_mel_token,

@@ -604,11 +614,39 @@ class UnifiedVoice(nn.Module):
             max_length=max_length,
             logits_processor=logits_processor,
             num_return_sequences=num_return_sequences,
+            attention_mask=attention_mask,
             **hf_generate_kwargs,
         )
         return gen[:, trunc_index:]
 
 
+def _prepare_attention_mask_for_generation(
+    inputs: torch.Tensor,
+    pad_token_id: Optional[torch.Tensor],
+    eos_token_id: Optional[torch.Tensor],
+) -> torch.LongTensor:
+    # No information for attention mask inference -> return default attention mask
+    default_attention_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
+    if pad_token_id is None:
+        return default_attention_mask
+
+    is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
+    if not is_input_ids:
+        return default_attention_mask
+
+    is_pad_token_in_inputs = (pad_token_id is not None) and (isin(elements=inputs, test_elements=pad_token_id).any())
+    is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
+        isin(elements=eos_token_id, test_elements=pad_token_id).any()
+    )
+    can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
+    attention_mask_from_padding = inputs.ne(pad_token_id).long()
+
+    attention_mask = (
+        attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
+    )
+    return attention_mask
+
+
 if __name__ == "__main__":
     gpt = UnifiedVoice(
         model_dim=256,
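Note: `_prepare_attention_mask_for_generation` above mirrors the mask-inference logic in Hugging Face's `GenerationMixin`: it derives the mask from padding only when the inputs are 2-D token ids, the pad token actually occurs in them, and pad differs from EOS; in every other case it returns an explicit all-ones mask. Passing that mask to `generate()` is what silences the "attention mask is not set" warning. A toy call of the helper, with invented token values:

import torch

from TTS.tts.layers.tortoise.autoregressive import _prepare_attention_mask_for_generation

input_ids = torch.tensor([[101, 7, 8, 0, 0]])  # invented ids, 0 used as padding
pad = torch.tensor(0)
eos = torch.tensor(2)                          # differs from pad, so the mask can be inferred
mask = _prepare_attention_mask_for_generation(input_ids, pad, eos)
print(mask)  # tensor([[1, 1, 1, 0, 0]])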
@@ -8,6 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers import GPT2Config
 
+from TTS.tts.layers.tortoise.autoregressive import _prepare_attention_mask_for_generation
 from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
 from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
 from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler

@@ -586,12 +587,15 @@ class GPT(nn.Module):
         **hf_generate_kwargs,
     ):
         gpt_inputs = self.compute_embeddings(cond_latents, text_inputs)
+        stop_token_tensor = torch.tensor(self.stop_audio_token, device=gpt_inputs.device, dtype=torch.long)
+        attention_mask = _prepare_attention_mask_for_generation(gpt_inputs, stop_token_tensor, stop_token_tensor)
         gen = self.gpt_inference.generate(
             gpt_inputs,
             bos_token_id=self.start_audio_token,
             pad_token_id=self.stop_audio_token,
             eos_token_id=self.stop_audio_token,
             max_length=self.max_gen_mel_tokens + gpt_inputs.shape[-1],
+            attention_mask=attention_mask,
             **hf_generate_kwargs,
         )
         if "return_dict_in_generate" in hf_generate_kwargs:
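Note: the same pattern as in `UnifiedVoice` is applied here, with `stop_audio_token` doubling as pad and EOS. Because pad and EOS are identical, the helper above falls back to an all-ones mask, which is enough to keep `generate()` from warning. A hedged sketch of the generic Hugging Face pattern this follows, using stock GPT-2 rather than the XTTS inference model (model name and prompt are illustrative only):

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

enc = tok("hello world", return_tensors="pt")
out = model.generate(
    enc["input_ids"],
    attention_mask=enc["attention_mask"],  # explicit mask -> no "attention mask is not set" warning
    pad_token_id=tok.eos_token_id,         # GPT-2 has no pad token; reusing EOS is the common workaround
    max_length=20,
)
print(tok.decode(out[0]))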
@@ -1,10 +1,12 @@
 import torch
 from torch import nn
-from transformers import GPT2PreTrainedModel
+from transformers import GenerationMixin, GPT2PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 
+from TTS.tts.layers.xtts.stream_generator import StreamGenerationConfig
+
 
-class GPT2InferenceModel(GPT2PreTrainedModel):
+class GPT2InferenceModel(GPT2PreTrainedModel, GenerationMixin):
     """Override GPT2LMHeadModel to allow for prefix conditioning."""
 
     def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):

@@ -15,6 +17,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         self.final_norm = norm
         self.lm_head = nn.Sequential(norm, linear)
         self.kv_cache = kv_cache
+        self.generation_config = StreamGenerationConfig.from_model_config(config) if self.can_generate() else None
 
     def store_prefix_emb(self, prefix_emb):
         self.cached_prefix_emb = prefix_emb
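Note: recent transformers releases warn that custom models which override `prepare_inputs_for_generation` must inherit `GenerationMixin` explicitly, because `PreTrainedModel` stops providing `generate()` in an upcoming release (v4.50 per the deprecation warning); the extra base class and the `generation_config` assignment above address that. A minimal sketch of the same pattern, using the stock `GenerationConfig` instead of the project's `StreamGenerationConfig` and placeholder config sizes:

from transformers import GenerationConfig, GenerationMixin, GPT2Config, GPT2PreTrainedModel


class TinyInferenceModel(GPT2PreTrainedModel, GenerationMixin):
    """Explicitly mix in GenerationMixin so .generate() keeps working on newer transformers."""

    def __init__(self, config):
        super().__init__(config)
        # Mirror the diff: only attach a generation config when the class can generate.
        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None


model = TinyInferenceModel(GPT2Config(n_layer=1, n_head=2, n_embd=64))  # placeholder sizes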
@@ -667,6 +667,7 @@ class Xtts(BaseTTS):
                 repetition_penalty=float(repetition_penalty),
+                output_attentions=False,
                 output_hidden_states=True,
                 return_dict_in_generate=True,
                 **hf_generate_kwargs,
             )
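Note: `output_attentions=False` addresses the remaining warning: with the SDPA attention implementation, `torch.nn.functional.scaled_dot_product_attention` cannot return attention weights, so requesting them forces a fallback to the eager implementation and transformers warns about it; `output_hidden_states` and `return_dict_in_generate` are left as before. A hedged sketch of the same kwargs on a stock GPT-2 model (model name and prompt are not from the diff):

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2", attn_implementation="sdpa")

enc = tok("hello", return_tensors="pt")
out = model.generate(
    **enc,
    max_length=16,
    output_attentions=False,       # SDPA cannot return attention weights; asking for them triggers a fallback warning
    output_hidden_states=True,
    return_dict_in_generate=True,  # return a structured output that carries the hidden states
    pad_token_id=tok.eos_token_id,
)
print(len(out.hidden_states))      # one entry per generated step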