mirror of https://github.com/coqui-ai/TTS.git
refactor(xtts): use position embedding from tortoise
This commit is contained in:
parent 5ffc0543b7
commit 490c973371
@@ -1,5 +1,6 @@
 # AGPL: a notification must be added stating that changes have been made to that file.
 import functools
+import random
 from typing import Optional
 
 import torch
@@ -123,7 +124,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         else:
             emb = self.embeddings(input_ids)
             emb = emb + self.text_pos_embedding.get_fixed_embedding(
-                attention_mask.shape[1] - mel_len, attention_mask.device
+                attention_mask.shape[1] - (mel_len + 1), attention_mask.device
             )
 
         transformer_outputs = self.transformer(
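Note: the extra "+ 1" on the caller side pairs with the rewritten get_fixed_embedding in the next hunk, which now takes the position index itself instead of a length to embed and slice. A minimal sketch (standard PyTorch only; the embedding table and sequence lengths are made-up stand-ins, not values from the repo) showing that both call styles select the same row:

import torch
import torch.nn as nn

emb = nn.Embedding(512, 16)    # stand-in for LearnedPositionEmbeddings.emb
seq_len, mel_len = 40, 25      # hypothetical attention_mask.shape[1] and mel_len

# old style: embed positions 0..ind-1 and keep the last row -> position ind - 1
ind_old = seq_len - mel_len
old = emb(torch.arange(0, ind_old))[ind_old - 1 : ind_old]    # shape (1, 16)

# new style: embed position ind directly; the caller subtracts one more
ind_new = seq_len - (mel_len + 1)
new = emb(torch.tensor([ind_new])).unsqueeze(0)               # shape (1, 1, 16)

assert torch.equal(old, new.squeeze(0))    # same position embedding either way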
@@ -196,18 +197,24 @@ class ConditioningEncoder(nn.Module):
 
 
 class LearnedPositionEmbeddings(nn.Module):
-    def __init__(self, seq_len, model_dim, init=0.02):
+    def __init__(self, seq_len, model_dim, init=0.02, relative=False):
         super().__init__()
         self.emb = nn.Embedding(seq_len, model_dim)
         # Initializing this way is standard for GPT-2
         self.emb.weight.data.normal_(mean=0.0, std=init)
+        self.relative = relative
+        self.seq_len = seq_len
 
     def forward(self, x):
         sl = x.shape[1]
-        return self.emb(torch.arange(0, sl, device=x.device))
+        if self.relative:
+            start = random.randint(sl, self.seq_len) - sl
+            return self.emb(torch.arange(start, start + sl, device=x.device))
+        else:
+            return self.emb(torch.arange(0, sl, device=x.device))
 
     def get_fixed_embedding(self, ind, dev):
-        return self.emb(torch.arange(0, ind, device=dev))[ind - 1 : ind]
+        return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
 
 
 def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
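Note: a short usage sketch of the merged class (sizes here are illustrative, not the model's real configuration; the import path is the one introduced in the next hunk). In the default mode it embeds positions 0..sl-1; with relative=True it embeds a randomly shifted window of the same length, and get_fixed_embedding serves the single-step decoding path used by GPT2InferenceModel above:

import torch
from TTS.tts.layers.tortoise.autoregressive import LearnedPositionEmbeddings

pos = LearnedPositionEmbeddings(seq_len=404, model_dim=1024)                      # absolute positions
pos_rel = LearnedPositionEmbeddings(seq_len=404, model_dim=1024, relative=True)   # random window offset

tokens = torch.zeros(2, 50, dtype=torch.long)    # (batch, sequence) of token ids
print(pos(tokens).shape)        # torch.Size([50, 1024]); broadcasts over the batch when added
print(pos_rel(tokens).shape)    # torch.Size([50, 1024]); same shape, shifted slice of the table

print(pos.get_fixed_embedding(49, tokens.device).shape)    # torch.Size([1, 1, 1024])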
@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers import GPT2Config
 
-from TTS.tts.layers.tortoise.autoregressive import _prepare_attention_mask_for_generation
+from TTS.tts.layers.tortoise.autoregressive import LearnedPositionEmbeddings, _prepare_attention_mask_for_generation
 from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
 from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
 from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler
@@ -18,28 +18,6 @@ def null_position_embeddings(range, dim):
     return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
 
 
-class LearnedPositionEmbeddings(nn.Module):
-    def __init__(self, seq_len, model_dim, init=0.02, relative=False):
-        super().__init__()
-        # nn.Embedding
-        self.emb = torch.nn.Embedding(seq_len, model_dim)
-        # Initializing this way is standard for GPT-2
-        self.emb.weight.data.normal_(mean=0.0, std=init)
-        self.relative = relative
-        self.seq_len = seq_len
-
-    def forward(self, x):
-        sl = x.shape[1]
-        if self.relative:
-            start = random.randint(sl, self.seq_len) - sl
-            return self.emb(torch.arange(start, start + sl, device=x.device))
-        else:
-            return self.emb(torch.arange(0, sl, device=x.device))
-
-    def get_fixed_embedding(self, ind, dev):
-        return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
-
-
 def build_hf_gpt_transformer(
     layers,
     model_dim,
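Note: with the duplicate class above removed, the XTTS GPT module keeps the exact class it already had, now shared via TTS.tts.layers.tortoise.autoregressive (see the import change in the previous hunk), so only the duplication goes away. A rough sketch of the resulting wiring; the sizes and the mel_pos_embedding name are illustrative assumptions, while text_pos_embedding is the attribute GPT2InferenceModel queries in the hunk at line 123 above:

from TTS.tts.layers.tortoise.autoregressive import LearnedPositionEmbeddings

model_dim, max_mel_seq_len, max_text_seq_len = 1024, 605, 404    # hypothetical sizes

# separate learned position tables for the mel and text token streams
mel_pos_embedding = LearnedPositionEmbeddings(max_mel_seq_len, model_dim)
text_pos_embedding = LearnedPositionEmbeddings(max_text_seq_len, model_dim)
# text_pos_embedding.get_fixed_embedding(...) is what the inference model calls
# when decoding one token at a time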