mirror of https://github.com/coqui-ai/TTS.git
fix(gpt): set attention mask and address other warnings
commit 964b813235
parent b66c782931
@@ -1,14 +1,22 @@
 # AGPL: a notification must be added stating that changes have been made to that file.
 import functools
+from typing import Optional
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import transformers
+from packaging.version import Version
 from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 
 from TTS.tts.layers.tortoise.arch_utils import AttentionBlock, TypicalLogitsWarper
 
+if Version(transformers.__version__) >= Version("4.45"):
+    isin = transformers.pytorch_utils.isin_mps_friendly
+else:
+    isin = torch.isin
+
 
 def null_position_embeddings(range, dim):
     return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
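Note: the version gate above presumably exists because `torch.isin` has not worked reliably on Apple's MPS backend, so newer transformers releases ship `transformers.pytorch_utils.isin_mps_friendly` as a drop-in replacement; older installs fall back to plain `torch.isin`. A minimal sketch of what either callable computes, with made-up token values:

    import torch

    tokens = torch.tensor([[42, 17, 8193, 8193]])  # hypothetical mel-token ids
    stop = torch.tensor(8193)                      # hypothetical stop token id
    # Marks which elements of `tokens` occur in `stop`.
    print(torch.isin(elements=tokens, test_elements=stop))
    # tensor([[False, False,  True,  True]])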
@@ -596,6 +604,8 @@ class UnifiedVoice(nn.Module):
         max_length = (
             trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
         )
+        stop_token_tensor = torch.tensor(self.stop_mel_token, device=inputs.device, dtype=torch.long)
+        attention_mask = _prepare_attention_mask_for_generation(inputs, stop_token_tensor, stop_token_tensor)
         gen = self.inference_model.generate(
             inputs,
             bos_token_id=self.start_mel_token,
@@ -604,11 +614,39 @@ class UnifiedVoice(nn.Module):
             max_length=max_length,
             logits_processor=logits_processor,
             num_return_sequences=num_return_sequences,
+            attention_mask=attention_mask,
             **hf_generate_kwargs,
         )
         return gen[:, trunc_index:]
 
 
+def _prepare_attention_mask_for_generation(
+    inputs: torch.Tensor,
+    pad_token_id: Optional[torch.Tensor],
+    eos_token_id: Optional[torch.Tensor],
+) -> torch.LongTensor:
+    # No information for attention mask inference -> return default attention mask
+    default_attention_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
+    if pad_token_id is None:
+        return default_attention_mask
+
+    is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
+    if not is_input_ids:
+        return default_attention_mask
+
+    is_pad_token_in_inputs = (pad_token_id is not None) and (isin(elements=inputs, test_elements=pad_token_id).any())
+    is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
+        isin(elements=eos_token_id, test_elements=pad_token_id).any()
+    )
+    can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
+    attention_mask_from_padding = inputs.ne(pad_token_id).long()
+
+    attention_mask = (
+        attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
+    )
+    return attention_mask
+
+
 if __name__ == "__main__":
     gpt = UnifiedVoice(
         model_dim=256,
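The helper added above mirrors the attention-mask inference that Hugging Face's `generate()` performs internally: a padding-derived mask is only used when the inputs are 2-D integer token ids and the pad token can be distinguished from the eos token; in every other case it returns an all-ones mask. A rough usage sketch with made-up ids (the import path follows the next hunk, which pulls the helper from TTS.tts.layers.tortoise.autoregressive):

    import torch

    from TTS.tts.layers.tortoise.autoregressive import _prepare_attention_mask_for_generation

    inputs = torch.tensor([[42, 17, 0, 0],
                           [7, 99, 311, 0]])  # hypothetical ids, 0 used as padding
    pad = torch.tensor(0)
    eos = torch.tensor(8193)                  # hypothetical eos id

    # Distinct pad/eos: the mask is inferred from the padding positions.
    print(_prepare_attention_mask_for_generation(inputs, pad, eos))
    # tensor([[1, 1, 0, 0],
    #         [1, 1, 1, 0]])

    # pad == eos, which is the case in this commit (the stop token plays both
    # roles): padding cannot be told apart from a real eos, so the helper
    # returns the all-ones default. Passing that explicit mask to generate()
    # is still enough to avoid the transformers warning about a missing
    # attention mask.
    print(_prepare_attention_mask_for_generation(inputs, pad, pad))
    # tensor([[1, 1, 1, 1],
    #         [1, 1, 1, 1]])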
@@ -8,6 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers import GPT2Config
 
+from TTS.tts.layers.tortoise.autoregressive import _prepare_attention_mask_for_generation
 from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
 from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
 from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler
@@ -586,12 +587,15 @@ class GPT(nn.Module):
         **hf_generate_kwargs,
     ):
         gpt_inputs = self.compute_embeddings(cond_latents, text_inputs)
+        stop_token_tensor = torch.tensor(self.stop_audio_token, device=gpt_inputs.device, dtype=torch.long)
+        attention_mask = _prepare_attention_mask_for_generation(gpt_inputs, stop_token_tensor, stop_token_tensor)
         gen = self.gpt_inference.generate(
             gpt_inputs,
             bos_token_id=self.start_audio_token,
             pad_token_id=self.stop_audio_token,
             eos_token_id=self.stop_audio_token,
             max_length=self.max_gen_mel_tokens + gpt_inputs.shape[-1],
+            attention_mask=attention_mask,
             **hf_generate_kwargs,
         )
         if "return_dict_in_generate" in hf_generate_kwargs:
@@ -1,10 +1,12 @@
 import torch
 from torch import nn
-from transformers import GPT2PreTrainedModel
+from transformers import GenerationMixin, GPT2PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 
+from TTS.tts.layers.xtts.stream_generator import StreamGenerationConfig
+
 
-class GPT2InferenceModel(GPT2PreTrainedModel):
+class GPT2InferenceModel(GPT2PreTrainedModel, GenerationMixin):
     """Override GPT2LMHeadModel to allow for prefix conditioning."""
 
     def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
@@ -15,6 +17,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         self.final_norm = norm
         self.lm_head = nn.Sequential(norm, linear)
         self.kv_cache = kv_cache
+        self.generation_config = StreamGenerationConfig.from_model_config(config) if self.can_generate() else None
 
     def store_prefix_emb(self, prefix_emb):
         self.cached_prefix_emb = prefix_emb
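The two hunks above track a change in recent transformers releases: generation support is no longer implied by `PreTrainedModel`, and models that call `.generate()` are expected to inherit `GenerationMixin` explicitly (the deprecation notice points at v4.50 for removing the implicit behaviour). Assigning `self.generation_config` in `__init__` mirrors what transformers does for generate-capable models, except that the project's `StreamGenerationConfig` is used so streaming options are carried along. A minimal sketch of the expected inheritance pattern, with a hypothetical model class:

    from transformers import GenerationMixin, GPT2Config, GPT2PreTrainedModel

    # Hypothetical minimal subclass: on recent transformers, inheriting
    # GenerationMixin explicitly is what keeps can_generate() truthful and
    # .generate() available; GPT2PreTrainedModel alone no longer provides them.
    class TinyInferenceModel(GPT2PreTrainedModel, GenerationMixin):
        def __init__(self, config: GPT2Config):
            super().__init__(config)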
@@ -667,6 +667,7 @@ class Xtts(BaseTTS):
             repetition_penalty=float(repetition_penalty),
             output_attentions=False,
             output_hidden_states=True,
+            return_dict_in_generate=True,
             **hf_generate_kwargs,
         )
 
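For context on the last hunk: with `return_dict_in_generate=True`, Hugging Face's `generate()` returns a structured output object rather than a bare tensor of ids, so the generated sequences and the hidden states requested via `output_hidden_states=True` travel together. A self-contained sketch with a tiny, randomly initialised GPT-2 (all sizes and values are illustrative):

    import torch
    from transformers import GPT2Config, GPT2LMHeadModel

    # Tiny random model purely for illustration.
    model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=2, n_embd=64))
    input_ids = torch.randint(0, model.config.vocab_size, (1, 5))
    attention_mask = torch.ones_like(input_ids)

    out = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=3,
        do_sample=False,
        pad_token_id=model.config.eos_token_id,
        output_hidden_states=True,
        return_dict_in_generate=True,
    )
    print(out.sequences.shape)     # torch.Size([1, 8]): prompt + generated ids
    print(len(out.hidden_states))  # one entry per generated step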