refactor(xtts): use position embedding from tortoise

Enno Hermann 2024-11-21 15:05:37 +01:00
parent 5ffc0543b7
commit 490c973371
2 changed files with 12 additions and 27 deletions

TTS/tts/layers/tortoise/autoregressive.py

@@ -1,5 +1,6 @@
 # AGPL: a notification must be added stating that changes have been made to that file.
 import functools
+import random
 from typing import Optional
 
 import torch
@@ -123,7 +124,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         else:
             emb = self.embeddings(input_ids)
             emb = emb + self.text_pos_embedding.get_fixed_embedding(
-                attention_mask.shape[1] - mel_len, attention_mask.device
+                attention_mask.shape[1] - (mel_len + 1), attention_mask.device
             )
 
         transformer_outputs = self.transformer(
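
The `(mel_len + 1)` compensates for the new `get_fixed_embedding` below: the old Tortoise form embedded positions `0..ind` and sliced out row `ind - 1`, while the new form looks up index `ind` directly, so the caller must pass an index one smaller to select the same row. A minimal sketch (toy table sizes assumed) checking the equivalence:

```python
import torch
import torch.nn as nn

# Toy embedding table standing in for LearnedPositionEmbeddings.emb.
emb = nn.Embedding(16, 4)
pos = 5  # position of the token currently being decoded

# Old Tortoise form: embed 0..pos, then slice out the last row (index pos - 1).
old = emb(torch.arange(0, pos))[pos - 1 : pos]        # shape (1, 4)

# New form: look up a single index directly. Passing pos - 1 (the extra "+ 1"
# subtracted by the caller) selects the same row, now shaped for broadcasting.
new = emb(torch.tensor([pos - 1])).unsqueeze(0)       # shape (1, 1, 4)

assert torch.equal(old, new.squeeze(0))
```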
@@ -196,18 +197,24 @@ class ConditioningEncoder(nn.Module):
 
 
 class LearnedPositionEmbeddings(nn.Module):
-    def __init__(self, seq_len, model_dim, init=0.02):
+    def __init__(self, seq_len, model_dim, init=0.02, relative=False):
         super().__init__()
         self.emb = nn.Embedding(seq_len, model_dim)
         # Initializing this way is standard for GPT-2
         self.emb.weight.data.normal_(mean=0.0, std=init)
+        self.relative = relative
+        self.seq_len = seq_len
 
     def forward(self, x):
         sl = x.shape[1]
-        return self.emb(torch.arange(0, sl, device=x.device))
+        if self.relative:
+            start = random.randint(sl, self.seq_len) - sl
+            return self.emb(torch.arange(start, start + sl, device=x.device))
+        else:
+            return self.emb(torch.arange(0, sl, device=x.device))
 
     def get_fixed_embedding(self, ind, dev):
-        return self.emb(torch.arange(0, ind, device=dev))[ind - 1 : ind]
+        return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
 
 
 def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
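
The `relative=False` path keeps the previous behavior; with `relative=True` the forward pass embeds the sequence at a random offset inside the table, so rows beyond the current sequence length also receive gradient during training. A quick illustration of the offset sampling (toy sizes assumed):

```python
import random

import torch
import torch.nn as nn

seq_len, model_dim = 8, 4
emb = nn.Embedding(seq_len, model_dim)
x = torch.zeros(2, 5, model_dim)  # batch of two length-5 sequences
sl = x.shape[1]

# random.randint is inclusive on both ends, so start lands in [0, seq_len - sl]
# and the window [start, start + sl) always stays inside the table.
start = random.randint(sl, seq_len) - sl
pos = emb(torch.arange(start, start + sl))  # (5, 4), broadcasts over the batch
```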

TTS/tts/layers/xtts/gpt.py

@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers import GPT2Config
 
-from TTS.tts.layers.tortoise.autoregressive import _prepare_attention_mask_for_generation
+from TTS.tts.layers.tortoise.autoregressive import LearnedPositionEmbeddings, _prepare_attention_mask_for_generation
 from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
 from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
 from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler
@@ -18,28 +18,6 @@ def null_position_embeddings(range, dim):
     return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
 
 
-class LearnedPositionEmbeddings(nn.Module):
-    def __init__(self, seq_len, model_dim, init=0.02, relative=False):
-        super().__init__()
-        # nn.Embedding
-        self.emb = torch.nn.Embedding(seq_len, model_dim)
-        # Initializing this way is standard for GPT-2
-        self.emb.weight.data.normal_(mean=0.0, std=init)
-        self.relative = relative
-        self.seq_len = seq_len
-
-    def forward(self, x):
-        sl = x.shape[1]
-        if self.relative:
-            start = random.randint(sl, self.seq_len) - sl
-            return self.emb(torch.arange(start, start + sl, device=x.device))
-        else:
-            return self.emb(torch.arange(0, sl, device=x.device))
-
-    def get_fixed_embedding(self, ind, dev):
-        return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
-
-
 def build_hf_gpt_transformer(
     layers,
     model_dim,
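
With the duplicate class deleted, gpt.py now reuses the Tortoise implementation, which the first file's diff extends with the `relative` option so no functionality is lost. A sketch of the resulting usage; the sizes are placeholders, the real values come from the GPT config:

```python
from TTS.tts.layers.tortoise.autoregressive import LearnedPositionEmbeddings

# Placeholder sizes for illustration only.
model_dim, max_text_seq_len = 1024, 402
text_pos_embedding = LearnedPositionEmbeddings(max_text_seq_len, model_dim)
```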