fix(gpt): set attention mask and address other warnings

Enno Hermann 2024-10-25 17:50:24 +02:00
parent b66c782931
commit 964b813235
4 changed files with 48 additions and 2 deletions

TTS/tts/layers/tortoise/autoregressive.py

@@ -1,14 +1,22 @@
 # AGPL: a notification must be added stating that changes have been made to that file.
 import functools
+from typing import Optional
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import transformers
+from packaging.version import Version
 from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 
 from TTS.tts.layers.tortoise.arch_utils import AttentionBlock, TypicalLogitsWarper
 
+if Version(transformers.__version__) >= Version("4.45"):
+    isin = transformers.pytorch_utils.isin_mps_friendly
+else:
+    isin = torch.isin
+
 
 def null_position_embeddings(range, dim):
     return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
@@ -596,6 +604,8 @@ class UnifiedVoice(nn.Module):
         max_length = (
             trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
         )
+        stop_token_tensor = torch.tensor(self.stop_mel_token, device=inputs.device, dtype=torch.long)
+        attention_mask = _prepare_attention_mask_for_generation(inputs, stop_token_tensor, stop_token_tensor)
         gen = self.inference_model.generate(
             inputs,
             bos_token_id=self.start_mel_token,
@@ -604,11 +614,39 @@ class UnifiedVoice(nn.Module):
             max_length=max_length,
             logits_processor=logits_processor,
             num_return_sequences=num_return_sequences,
+            attention_mask=attention_mask,
             **hf_generate_kwargs,
         )
         return gen[:, trunc_index:]
 
+
+def _prepare_attention_mask_for_generation(
+    inputs: torch.Tensor,
+    pad_token_id: Optional[torch.Tensor],
+    eos_token_id: Optional[torch.Tensor],
+) -> torch.LongTensor:
+    # No information for attention mask inference -> return default attention mask
+    default_attention_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
+    if pad_token_id is None:
+        return default_attention_mask
+
+    is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
+    if not is_input_ids:
+        return default_attention_mask
+
+    is_pad_token_in_inputs = (pad_token_id is not None) and (isin(elements=inputs, test_elements=pad_token_id).any())
+    is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
+        isin(elements=eos_token_id, test_elements=pad_token_id).any()
+    )
+    can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
+    attention_mask_from_padding = inputs.ne(pad_token_id).long()
+
+    attention_mask = (
+        attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
+    )
+    return attention_mask
+
 
 if __name__ == "__main__":
     gpt = UnifiedVoice(
         model_dim=256,
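The helper added above looks like a local copy of the private `_prepare_attention_mask_for_generation` logic on transformers' `GenerationMixin`, with `isin` swapped for the MPS-friendly variant on transformers >= 4.45. A rough standalone illustration of what it returns (not part of the commit; the token ids 0, 2, 5, 6 are made up):

```python
import torch

from TTS.tts.layers.tortoise.autoregressive import _prepare_attention_mask_for_generation

# Case 1: pad token differs from EOS and occurs in the input -> mask out the padding.
inputs = torch.tensor([[5, 6, 0, 0]], dtype=torch.long)  # 0 is a hypothetical pad id
pad, eos = torch.tensor(0), torch.tensor(2)
print(_prepare_attention_mask_for_generation(inputs, pad, eos))  # tensor([[1, 1, 0, 0]])

# Case 2: pad token == EOS (the situation in this commit, where the stop token plays
# both roles) -> padding cannot be told apart from EOS, so the helper falls back to
# an all-ones default mask.
stop = torch.tensor(2)
inputs_with_stop = torch.tensor([[5, 6, 2, 2]], dtype=torch.long)
print(_prepare_attention_mask_for_generation(inputs_with_stop, stop, stop))  # tensor([[1, 1, 1, 1]])
```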

TTS/tts/layers/xtts/gpt.py

@@ -8,6 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers import GPT2Config
 
+from TTS.tts.layers.tortoise.autoregressive import _prepare_attention_mask_for_generation
 from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
 from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
 from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler
@@ -586,12 +587,15 @@ class GPT(nn.Module):
         **hf_generate_kwargs,
     ):
         gpt_inputs = self.compute_embeddings(cond_latents, text_inputs)
+        stop_token_tensor = torch.tensor(self.stop_audio_token, device=gpt_inputs.device, dtype=torch.long)
+        attention_mask = _prepare_attention_mask_for_generation(gpt_inputs, stop_token_tensor, stop_token_tensor)
         gen = self.gpt_inference.generate(
             gpt_inputs,
             bos_token_id=self.start_audio_token,
             pad_token_id=self.stop_audio_token,
             eos_token_id=self.stop_audio_token,
             max_length=self.max_gen_mel_tokens + gpt_inputs.shape[-1],
+            attention_mask=attention_mask,
             **hf_generate_kwargs,
         )
         if "return_dict_in_generate" in hf_generate_kwargs:

TTS/tts/layers/xtts/gpt_inference.py

@@ -1,10 +1,12 @@
 import torch
 from torch import nn
-from transformers import GPT2PreTrainedModel
+from transformers import GenerationMixin, GPT2PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 
+from TTS.tts.layers.xtts.stream_generator import StreamGenerationConfig
+
 
-class GPT2InferenceModel(GPT2PreTrainedModel):
+class GPT2InferenceModel(GPT2PreTrainedModel, GenerationMixin):
     """Override GPT2LMHeadModel to allow for prefix conditioning."""
 
     def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
@@ -15,6 +17,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         self.final_norm = norm
         self.lm_head = nn.Sequential(norm, linear)
         self.kv_cache = kv_cache
+        self.generation_config = StreamGenerationConfig.from_model_config(config) if self.can_generate() else None
 
     def store_prefix_emb(self, prefix_emb):
         self.cached_prefix_emb = prefix_emb
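Background for the `GenerationMixin` change: transformers warns that from v4.50 onward `PreTrainedModel` will no longer provide `generate()`, so custom decoder wrappers must inherit the mixin explicitly. The added `generation_config` line mirrors what `PreTrainedModel.__init__` normally sets up, just with the project's `StreamGenerationConfig` subclass. A stripped-down sketch of the same shape, using the stock `GenerationConfig` so it stays self-contained (the class name `TinyInferenceModel` and the tiny config values are made up):

```python
from transformers import GenerationConfig, GenerationMixin, GPT2Config, GPT2PreTrainedModel


class TinyInferenceModel(GPT2PreTrainedModel, GenerationMixin):
    """Minimal stand-in showing the inheritance pattern; no forward pass defined."""

    def __init__(self, config: GPT2Config):
        super().__init__(config)
        # The commit uses StreamGenerationConfig here; the stock class keeps the sketch runnable.
        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None


model = TinyInferenceModel(GPT2Config(n_layer=1, n_head=1, n_embd=8))
# On recent transformers, inheriting GenerationMixin makes can_generate() return True,
# so a GenerationConfig is attached; older versions may leave it as None.
print(model.can_generate(), type(model.generation_config).__name__)
```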

TTS/tts/models/xtts.py

@@ -667,6 +667,7 @@ class Xtts(BaseTTS):
                 repetition_penalty=float(repetition_penalty),
                 output_attentions=False,
                 output_hidden_states=True,
+                return_dict_in_generate=True,
                 **hf_generate_kwargs,
             )
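The added `return_dict_in_generate=True` pairs with the existing `output_hidden_states=True`: transformers' generation-config validation warns that output flags have no effect unless `generate()` returns a `ModelOutput`, which appears to be one of the warnings this commit addresses. An illustration with plain GPT-2 rather than the XTTS model (not project code; checkpoint and prompt are arbitrary):

```python
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

tok = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
ids = tok("Hello", return_tensors="pt").input_ids

out = model.generate(
    ids,
    attention_mask=torch.ones_like(ids),
    pad_token_id=tok.eos_token_id,
    max_new_tokens=4,
    output_hidden_states=True,
    return_dict_in_generate=True,  # without this, generate() returns a bare tensor and the hidden states are dropped
)
print(out.sequences.shape)     # (1, prompt_len + up to 4 new tokens)
print(len(out.hidden_states))  # one entry per generation step, each a tuple of per-layer states
```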