From 490c973371c4a5ae345982325324efd0ece7f4af Mon Sep 17 00:00:00 2001
From: Enno Hermann
Date: Thu, 21 Nov 2024 15:05:37 +0100
Subject: [PATCH] refactor(xtts): use position embedding from tortoise

---
 TTS/tts/layers/tortoise/autoregressive.py | 15 +++++++++++----
 TTS/tts/layers/xtts/gpt.py                | 24 +-----------------------
 2 files changed, 12 insertions(+), 27 deletions(-)

diff --git a/TTS/tts/layers/tortoise/autoregressive.py b/TTS/tts/layers/tortoise/autoregressive.py
index e3ffd4d1..3463e63b 100644
--- a/TTS/tts/layers/tortoise/autoregressive.py
+++ b/TTS/tts/layers/tortoise/autoregressive.py
@@ -1,5 +1,6 @@
 # AGPL: a notification must be added stating that changes have been made to that file.
 import functools
+import random
 from typing import Optional
 
 import torch
@@ -123,7 +124,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         else:
             emb = self.embeddings(input_ids)
             emb = emb + self.text_pos_embedding.get_fixed_embedding(
-                attention_mask.shape[1] - mel_len, attention_mask.device
+                attention_mask.shape[1] - (mel_len + 1), attention_mask.device
             )
 
         transformer_outputs = self.transformer(
@@ -196,18 +197,24 @@ class ConditioningEncoder(nn.Module):
 
 
 class LearnedPositionEmbeddings(nn.Module):
-    def __init__(self, seq_len, model_dim, init=0.02):
+    def __init__(self, seq_len, model_dim, init=0.02, relative=False):
         super().__init__()
         self.emb = nn.Embedding(seq_len, model_dim)
         # Initializing this way is standard for GPT-2
         self.emb.weight.data.normal_(mean=0.0, std=init)
+        self.relative = relative
+        self.seq_len = seq_len
 
     def forward(self, x):
         sl = x.shape[1]
-        return self.emb(torch.arange(0, sl, device=x.device))
+        if self.relative:
+            start = random.randint(sl, self.seq_len) - sl
+            return self.emb(torch.arange(start, start + sl, device=x.device))
+        else:
+            return self.emb(torch.arange(0, sl, device=x.device))
 
     def get_fixed_embedding(self, ind, dev):
-        return self.emb(torch.arange(0, ind, device=dev))[ind - 1 : ind]
+        return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
 
 
 def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py
index b3c3b31b..f9328761 100644
--- a/TTS/tts/layers/xtts/gpt.py
+++ b/TTS/tts/layers/xtts/gpt.py
@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers import GPT2Config
 
-from TTS.tts.layers.tortoise.autoregressive import _prepare_attention_mask_for_generation
+from TTS.tts.layers.tortoise.autoregressive import LearnedPositionEmbeddings, _prepare_attention_mask_for_generation
 from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
 from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
 from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler
@@ -18,28 +18,6 @@ def null_position_embeddings(range, dim):
     return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
 
 
-class LearnedPositionEmbeddings(nn.Module):
-    def __init__(self, seq_len, model_dim, init=0.02, relative=False):
-        super().__init__()
-        # nn.Embedding
-        self.emb = torch.nn.Embedding(seq_len, model_dim)
-        # Initializing this way is standard for GPT-2
-        self.emb.weight.data.normal_(mean=0.0, std=init)
-        self.relative = relative
-        self.seq_len = seq_len
-
-    def forward(self, x):
-        sl = x.shape[1]
-        if self.relative:
-            start = random.randint(sl, self.seq_len) - sl
-            return self.emb(torch.arange(start, start + sl, device=x.device))
-        else:
-            return self.emb(torch.arange(0, sl, device=x.device))
-
-    def get_fixed_embedding(self, ind, dev):
-        return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
-
-
 def build_hf_gpt_transformer(
     layers,
     model_dim,
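
Below is a small standalone sketch (not part of the patch) of why the call site now passes mel_len + 1: the old get_fixed_embedding embedded positions 0..ind-1 and kept the last one, while the new version embeds position ind directly and adds a leading batch dimension, so the caller has to hand in an index that is one smaller. The embedding size and lengths are made-up values for illustration.

    import torch
    import torch.nn as nn

    emb = nn.Embedding(64, 8)  # stand-in for LearnedPositionEmbeddings.emb

    def old_fixed_embedding(ind, dev):
        # pre-patch behaviour: embed 0..ind-1, keep the last one -> shape (1, dim)
        return emb(torch.arange(0, ind, device=dev))[ind - 1 : ind]

    def new_fixed_embedding(ind, dev):
        # patched behaviour: embed position `ind` directly -> shape (1, 1, dim)
        return emb(torch.tensor([ind], device=dev)).unsqueeze(0)

    attention_len, mel_len = 30, 12  # hypothetical lengths
    old = old_fixed_embedding(attention_len - mel_len, "cpu")        # old call site
    new = new_fixed_embedding(attention_len - (mel_len + 1), "cpu")  # new call site
    assert torch.equal(old, new.squeeze(0))  # same embedding, new one keeps a batch dim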
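
The relative mode that moves into the tortoise module shifts the position window by a random offset during training instead of always starting at 0; a standalone illustration of the index arithmetic (values are made up):

    import random
    import torch

    seq_len, sl = 404, 50  # hypothetical max sequence length and current batch length
    # relative=False: positions always start at 0
    absolute_positions = torch.arange(0, sl)
    # relative=True: start at a random offset so the window still fits inside seq_len
    start = random.randint(sl, seq_len) - sl  # start lies in [0, seq_len - sl]
    relative_positions = torch.arange(start, start + sl)
    assert relative_positions.numel() == sl and relative_positions.max() < seq_len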