mirror of https://github.com/coqui-ai/TTS.git
refactor(xtts): use position embedding from tortoise
parent 5ffc0543b7
commit 490c973371
@@ -1,5 +1,6 @@
 # AGPL: a notification must be added stating that changes have been made to that file.
 import functools
+import random
 from typing import Optional

 import torch
@@ -123,7 +124,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         else:
             emb = self.embeddings(input_ids)
             emb = emb + self.text_pos_embedding.get_fixed_embedding(
-                attention_mask.shape[1] - mel_len, attention_mask.device
+                attention_mask.shape[1] - (mel_len + 1), attention_mask.device
             )

         transformer_outputs = self.transformer(
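Why the caller now subtracts (mel_len + 1): the old get_fixed_embedding(ind, dev) embedded positions 0..ind-1 and kept only the last row (position ind - 1), while the new version shown in the next hunk embeds position ind directly and adds a leading batch dim. A minimal sketch with toy sizes (illustration only, not part of the commit) checking that both variants select the same row:

# Toy check: old helper returns row ind - 1, new helper returns row ind, so passing
# (mel_len + 1) instead of mel_len keeps the selected position unchanged.
import torch

emb = torch.nn.Embedding(608, 4)  # arbitrary toy sizes

def old_fixed(ind, dev="cpu"):
    return emb(torch.arange(0, ind, device=dev))[ind - 1 : ind]   # shape (1, dim), row ind - 1

def new_fixed(ind, dev="cpu"):
    return emb(torch.tensor([ind], device=dev)).unsqueeze(0)      # shape (1, 1, dim), row ind

n, mel_len = 100, 40  # stand-ins for attention_mask.shape[1] and mel_len
assert torch.equal(old_fixed(n - mel_len), new_fixed(n - (mel_len + 1))[0])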
@@ -196,18 +197,24 @@ class ConditioningEncoder(nn.Module):


 class LearnedPositionEmbeddings(nn.Module):
-    def __init__(self, seq_len, model_dim, init=0.02):
+    def __init__(self, seq_len, model_dim, init=0.02, relative=False):
         super().__init__()
         self.emb = nn.Embedding(seq_len, model_dim)
         # Initializing this way is standard for GPT-2
         self.emb.weight.data.normal_(mean=0.0, std=init)
+        self.relative = relative
+        self.seq_len = seq_len

     def forward(self, x):
         sl = x.shape[1]
-        return self.emb(torch.arange(0, sl, device=x.device))
+        if self.relative:
+            start = random.randint(sl, self.seq_len) - sl
+            return self.emb(torch.arange(start, start + sl, device=x.device))
+        else:
+            return self.emb(torch.arange(0, sl, device=x.device))

     def get_fixed_embedding(self, ind, dev):
-        return self.emb(torch.arange(0, ind, device=dev))[ind - 1 : ind]
+        return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)


 def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
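Usage sketch (illustration only, not part of the commit): after this change the single LearnedPositionEmbeddings lives in the tortoise module and, as the next hunk shows, the xtts GPT module imports it from there. Sizes below are arbitrary.

import torch
from TTS.tts.layers.tortoise.autoregressive import LearnedPositionEmbeddings

pos = LearnedPositionEmbeddings(seq_len=404, model_dim=1024)     # arbitrary sizes
tokens = torch.zeros(2, 50, dtype=torch.long)                    # (batch, sequence) placeholder ids

print(pos(tokens).shape)                         # torch.Size([50, 1024]); added to token embeddings by broadcasting
print(pos.get_fixed_embedding(49, "cpu").shape)  # torch.Size([1, 1, 1024]); one position, for incremental decoding

# relative=True makes forward() embed a random contiguous window of the position
# table instead of always starting at position 0.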
@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers import GPT2Config

-from TTS.tts.layers.tortoise.autoregressive import _prepare_attention_mask_for_generation
+from TTS.tts.layers.tortoise.autoregressive import LearnedPositionEmbeddings, _prepare_attention_mask_for_generation
 from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
 from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
 from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler
@@ -18,28 +18,6 @@ def null_position_embeddings(range, dim):
     return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)


-class LearnedPositionEmbeddings(nn.Module):
-    def __init__(self, seq_len, model_dim, init=0.02, relative=False):
-        super().__init__()
-        # nn.Embedding
-        self.emb = torch.nn.Embedding(seq_len, model_dim)
-        # Initializing this way is standard for GPT-2
-        self.emb.weight.data.normal_(mean=0.0, std=init)
-        self.relative = relative
-        self.seq_len = seq_len
-
-    def forward(self, x):
-        sl = x.shape[1]
-        if self.relative:
-            start = random.randint(sl, self.seq_len) - sl
-            return self.emb(torch.arange(start, start + sl, device=x.device))
-        else:
-            return self.emb(torch.arange(0, sl, device=x.device))
-
-    def get_fixed_embedding(self, ind, dev):
-        return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
-
-
 def build_hf_gpt_transformer(
     layers,
     model_dim,
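Net effect: the relative option and the tensor-indexed get_fixed_embedding that previously lived only in the xtts copy removed above are now part of the tortoise LearnedPositionEmbeddings, and the xtts GPT module imports that class instead of defining its own. A small sketch (illustration only) of the window the relative mode samples from:

import random

seq_len, sl = 404, 50                                        # arbitrary table size and sequence length
starts = [random.randint(sl, seq_len) - sl for _ in range(10_000)]
assert 0 <= min(starts) and max(starts) <= seq_len - sl
# so torch.arange(start, start + sl) always indexes valid rows of nn.Embedding(seq_len, model_dim)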