refactor(xtts): use build_hf_gpt_transformer from tortoise

Enno Hermann 2024-11-21 15:33:36 +01:00
parent 490c973371
commit 33ac0d6ee1
2 changed files with 40 additions and 73 deletions
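
The XTTS copy of build_hf_gpt_transformer (together with its local null_position_embeddings helper) is removed, and the tortoise implementation is now shared by both models. The shared version gains type hints, the conditional position-embedding fallback that the XTTS copy had, and a trailing max_prompt_len argument defaulting to 0, so existing tortoise call sites keep their context size while XTTS reserves extra positions for the conditioning prompt. A quick sketch of the resulting context-size arithmetic (the helper below is illustrative only, not part of the codebase; the numbers are toy values, not the real config):

def n_positions(max_mel_seq_len, max_text_seq_len, max_prompt_len=0):
    # Mirrors the n_positions / n_ctx expression passed to GPT2Config below.
    return max_mel_seq_len + max_text_seq_len + max_prompt_len

assert n_positions(100, 50) == 150        # tortoise-style call: prompt budget defaults to 0
assert n_positions(100, 50, 70) == 220    # xtts-style call: conditioning prompt extends the context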

View File

@@ -217,7 +217,15 @@ class LearnedPositionEmbeddings(nn.Module):
         return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
-def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
+def build_hf_gpt_transformer(
+    layers: int,
+    model_dim: int,
+    heads: int,
+    max_mel_seq_len: int,
+    max_text_seq_len: int,
+    checkpointing: bool,
+    max_prompt_len: int = 0,
+):
     """
     GPT-2 implemented by the HuggingFace library.
     """
@@ -225,8 +233,8 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
     gpt_config = GPT2Config(
         vocab_size=256,  # Unused.
-        n_positions=max_mel_seq_len + max_text_seq_len,
-        n_ctx=max_mel_seq_len + max_text_seq_len,
+        n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len,
+        n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len,
         n_embd=model_dim,
         n_layer=layers,
         n_head=heads,
@@ -239,13 +247,18 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
     gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
     # Built-in token embeddings are unused.
     del gpt.wte
-    return (
-        gpt,
-        LearnedPositionEmbeddings(max_mel_seq_len, model_dim),
-        LearnedPositionEmbeddings(max_text_seq_len, model_dim),
-        None,
-        None,
-    )
+    mel_pos_emb = (
+        LearnedPositionEmbeddings(max_mel_seq_len, model_dim)
+        if max_mel_seq_len != -1
+        else functools.partial(null_position_embeddings, dim=model_dim)
+    )
+    text_pos_emb = (
+        LearnedPositionEmbeddings(max_text_seq_len, model_dim)
+        if max_mel_seq_len != -1
+        else functools.partial(null_position_embeddings, dim=model_dim)
+    )
+    return gpt, mel_pos_emb, text_pos_emb, None, None
 class MelEncoder(nn.Module):
@@ -339,12 +352,12 @@ class UnifiedVoice(nn.Module):
             self.mel_layer_pos_embedding,
             self.text_layer_pos_embedding,
         ) = build_hf_gpt_transformer(
-            layers,
-            model_dim,
-            heads,
-            self.max_mel_tokens + 2 + self.max_conditioning_inputs,
-            self.max_text_tokens + 2,
-            checkpointing,
+            layers=layers,
+            model_dim=model_dim,
+            heads=heads,
+            max_mel_seq_len=self.max_mel_tokens + 2 + self.max_conditioning_inputs,
+            max_text_seq_len=self.max_text_tokens + 2,
+            checkpointing=checkpointing,
         )
         if train_solo_embeddings:
             self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)

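Both position-embedding variants returned by the refactored builder above are plain callables on a (batch, seq, dim) tensor: LearnedPositionEmbeddings looks up one learned vector per position, while the functools.partial(null_position_embeddings, dim=...) fallback returns zeros of the same width. A small illustrative check (toy sizes; assumes null_position_embeddings is the module-level helper referenced in the hunks above):

import functools

import torch

from TTS.tts.layers.tortoise.autoregressive import LearnedPositionEmbeddings, null_position_embeddings

x = torch.randn(2, 7, 64)                                # (batch, seq, model_dim)
learned = LearnedPositionEmbeddings(100, 64)             # max sequence length 100
fallback = functools.partial(null_position_embeddings, dim=64)
print(learned(x).shape)   # torch.Size([7, 64]); broadcasts over the batch when added to x
print(fallback(x).shape)  # torch.Size([2, 7, 64]); all zeros, i.e. no positional information
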
View File

@@ -1,6 +1,5 @@
 # ported from: https://github.com/neonbjb/tortoise-tts
-import functools
 import random
 import torch
@@ -8,61 +7,16 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers import GPT2Config
-from TTS.tts.layers.tortoise.autoregressive import LearnedPositionEmbeddings, _prepare_attention_mask_for_generation
+from TTS.tts.layers.tortoise.autoregressive import (
+    LearnedPositionEmbeddings,
+    _prepare_attention_mask_for_generation,
+    build_hf_gpt_transformer,
+)
 from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
 from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
 from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler
-def null_position_embeddings(range, dim):
-    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
-def build_hf_gpt_transformer(
-    layers,
-    model_dim,
-    heads,
-    max_mel_seq_len,
-    max_text_seq_len,
-    max_prompt_len,
-    checkpointing,
-):
-    """
-    GPT-2 implemented by the HuggingFace library.
-    """
-    from transformers import GPT2Config, GPT2Model
-    gpt_config = GPT2Config(
-        vocab_size=256,  # Unused.
-        n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len,
-        n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len,
-        n_embd=model_dim,
-        n_layer=layers,
-        n_head=heads,
-        gradient_checkpointing=checkpointing,
-        use_cache=not checkpointing,
-    )
-    gpt = GPT2Model(gpt_config)
-    # Override the built in positional embeddings
-    del gpt.wpe
-    gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
-    # Built-in token embeddings are unused.
-    del gpt.wte
-    mel_pos_emb = (
-        LearnedPositionEmbeddings(max_mel_seq_len, model_dim)
-        if max_mel_seq_len != -1
-        else functools.partial(null_position_embeddings, dim=model_dim)
-    )
-    text_pos_emb = (
-        LearnedPositionEmbeddings(max_text_seq_len, model_dim)
-        if max_mel_seq_len != -1
-        else functools.partial(null_position_embeddings, dim=model_dim)
-    )
-    # gpt = torch.compile(gpt, mode="reduce-overhead", fullgraph=True)
-    return gpt, mel_pos_emb, text_pos_emb, None, None
 class GPT(nn.Module):
     def __init__(
         self,
@@ -127,13 +81,13 @@ class GPT(nn.Module):
             self.mel_layer_pos_embedding,
             self.text_layer_pos_embedding,
         ) = build_hf_gpt_transformer(
-            layers,
-            model_dim,
-            heads,
-            self.max_mel_tokens,
-            self.max_text_tokens,
-            self.max_prompt_tokens,
-            checkpointing,
+            layers=layers,
+            model_dim=model_dim,
+            heads=heads,
+            max_mel_seq_len=self.max_mel_tokens,
+            max_text_seq_len=self.max_text_tokens,
+            max_prompt_len=self.max_prompt_tokens,
+            checkpointing=checkpointing,
         )
         if train_solo_embeddings:
             self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)
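
For completeness, a minimal end-to-end sketch of the consolidated builder after this change, using the import path introduced above (toy sizes, illustration only, not part of the commit):

import torch

from TTS.tts.layers.tortoise.autoregressive import build_hf_gpt_transformer

gpt, mel_pos_emb, text_pos_emb, _, _ = build_hf_gpt_transformer(
    layers=2,
    model_dim=64,
    heads=4,
    max_mel_seq_len=100,
    max_text_seq_len=50,
    checkpointing=False,
    max_prompt_len=10,  # xtts-style call; tortoise call sites simply omit this
)

# Both models drive the GPT-2 trunk through inputs_embeds: its built-in wte is
# deleted and wpe is the zero stub, so positions come from the returned modules.
text_emb = torch.randn(1, 20, 64)
out = gpt(inputs_embeds=text_emb + text_pos_emb(text_emb), return_dict=True)
print(out.last_hidden_state.shape)  # torch.Size([1, 20, 64])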