mirror of https://github.com/coqui-ai/TTS.git

refactor(xtts): use build_hf_gpt_transformer from tortoise

parent 490c973371, commit 33ac0d6ee1
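Summary of the diff below: the XTTS GPT module previously carried its own copy of build_hf_gpt_transformer. This commit deletes that copy, extends the tortoise version in autoregressive.py with an optional max_prompt_len argument (default 0) and a -1 sentinel that swaps the learned positional embeddings for zero embeddings, and switches both call sites (UnifiedVoice and the XTTS GPT) to keyword arguments.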
TTS/tts/layers/tortoise/autoregressive.py

@@ -217,7 +217,15 @@ class LearnedPositionEmbeddings(nn.Module):
         return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
 
 
-def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
+def build_hf_gpt_transformer(
+    layers: int,
+    model_dim: int,
+    heads: int,
+    max_mel_seq_len: int,
+    max_text_seq_len: int,
+    checkpointing: bool,
+    max_prompt_len: int = 0,
+):
     """
     GPT-2 implemented by the HuggingFace library.
     """
@@ -225,8 +233,8 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
 
     gpt_config = GPT2Config(
         vocab_size=256,  # Unused.
-        n_positions=max_mel_seq_len + max_text_seq_len,
-        n_ctx=max_mel_seq_len + max_text_seq_len,
+        n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len,
+        n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len,
         n_embd=model_dim,
         n_layer=layers,
         n_head=heads,
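The prompt token budget now counts toward the GPT-2 context window. A toy calculation with assumed sizes (not the library defaults); callers that rely on the max_prompt_len=0 default keep the previous context size:

max_mel_seq_len, max_text_seq_len, max_prompt_len = 608, 404, 70  # assumed toy sizes

print(max_mel_seq_len + max_text_seq_len + max_prompt_len)  # 1082 -> n_positions / n_ctx when a prompt budget is passed
print(max_mel_seq_len + max_text_seq_len + 0)               # 1012 -> unchanged for callers that omit max_prompt_len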
@@ -239,13 +247,18 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
     gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
     # Built-in token embeddings are unused.
     del gpt.wte
-    return (
-        gpt,
-        LearnedPositionEmbeddings(max_mel_seq_len, model_dim),
-        LearnedPositionEmbeddings(max_text_seq_len, model_dim),
-        None,
-        None,
-    )
+
+    mel_pos_emb = (
+        LearnedPositionEmbeddings(max_mel_seq_len, model_dim)
+        if max_mel_seq_len != -1
+        else functools.partial(null_position_embeddings, dim=model_dim)
+    )
+    text_pos_emb = (
+        LearnedPositionEmbeddings(max_text_seq_len, model_dim)
+        if max_mel_seq_len != -1
+        else functools.partial(null_position_embeddings, dim=model_dim)
+    )
+    return gpt, mel_pos_emb, text_pos_emb, None, None
 
 
 class MelEncoder(nn.Module):
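The -1 sentinel added here lets a caller opt out of learned positional embeddings; the returned embedding callable then comes from the same null_position_embeddings helper the hunk already uses for gpt.wpe, which yields all-zero "embeddings". A minimal sketch of that fallback path (the tensor sizes are made-up toy values):

import functools

import torch


def null_position_embeddings(range, dim):
    # Same helper referenced above for gpt.wpe: zero "embeddings" of shape (batch, seq_len, dim).
    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)


model_dim = 64                                 # toy size, for illustration only
tokens = torch.zeros(2, 10, dtype=torch.long)  # toy (batch, seq_len) token tensor

# What build_hf_gpt_transformer now returns for the mel/text position embedding
# when max_mel_seq_len == -1:
mel_pos_emb = functools.partial(null_position_embeddings, dim=model_dim)
print(mel_pos_emb(tokens).shape)  # torch.Size([2, 10, 64]) -- all zeros, i.e. no positional signal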
@@ -339,12 +352,12 @@ class UnifiedVoice(nn.Module):
             self.mel_layer_pos_embedding,
             self.text_layer_pos_embedding,
         ) = build_hf_gpt_transformer(
-            layers,
-            model_dim,
-            heads,
-            self.max_mel_tokens + 2 + self.max_conditioning_inputs,
-            self.max_text_tokens + 2,
-            checkpointing,
+            layers=layers,
+            model_dim=model_dim,
+            heads=heads,
+            max_mel_seq_len=self.max_mel_tokens + 2 + self.max_conditioning_inputs,
+            max_text_seq_len=self.max_text_tokens + 2,
+            checkpointing=checkpointing,
         )
         if train_solo_embeddings:
             self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)
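For reference, a hypothetical standalone call to the shared helper after this change, mirroring how UnifiedVoice now invokes it (the dimensions are made-up example values, not model defaults):

from TTS.tts.layers.tortoise.autoregressive import build_hf_gpt_transformer

gpt, mel_pos_emb, text_pos_emb, _, _ = build_hf_gpt_transformer(
    layers=8,                      # toy values for illustration
    model_dim=512,
    heads=8,
    max_mel_seq_len=600 + 2 + 1,   # e.g. max_mel_tokens + 2 + max_conditioning_inputs, as in UnifiedVoice
    max_text_seq_len=400 + 2,      # e.g. max_text_tokens + 2
    checkpointing=False,
    # max_prompt_len is omitted: it defaults to 0, so tortoise callers keep the old context size.
)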
TTS/tts/layers/xtts/gpt.py

@@ -1,6 +1,5 @@
 # ported from: https://github.com/neonbjb/tortoise-tts
 
-import functools
 import random
 
 import torch
@@ -8,61 +7,16 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers import GPT2Config
 
-from TTS.tts.layers.tortoise.autoregressive import LearnedPositionEmbeddings, _prepare_attention_mask_for_generation
+from TTS.tts.layers.tortoise.autoregressive import (
+    LearnedPositionEmbeddings,
+    _prepare_attention_mask_for_generation,
+    build_hf_gpt_transformer,
+)
 from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
 from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
 from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler
 
 
-def null_position_embeddings(range, dim):
-    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
-
-
-def build_hf_gpt_transformer(
-    layers,
-    model_dim,
-    heads,
-    max_mel_seq_len,
-    max_text_seq_len,
-    max_prompt_len,
-    checkpointing,
-):
-    """
-    GPT-2 implemented by the HuggingFace library.
-    """
-    from transformers import GPT2Config, GPT2Model
-
-    gpt_config = GPT2Config(
-        vocab_size=256,  # Unused.
-        n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len,
-        n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len,
-        n_embd=model_dim,
-        n_layer=layers,
-        n_head=heads,
-        gradient_checkpointing=checkpointing,
-        use_cache=not checkpointing,
-    )
-    gpt = GPT2Model(gpt_config)
-    # Override the built in positional embeddings
-    del gpt.wpe
-    gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
-    # Built-in token embeddings are unused.
-    del gpt.wte
-
-    mel_pos_emb = (
-        LearnedPositionEmbeddings(max_mel_seq_len, model_dim)
-        if max_mel_seq_len != -1
-        else functools.partial(null_position_embeddings, dim=model_dim)
-    )
-    text_pos_emb = (
-        LearnedPositionEmbeddings(max_text_seq_len, model_dim)
-        if max_mel_seq_len != -1
-        else functools.partial(null_position_embeddings, dim=model_dim)
-    )
-    # gpt = torch.compile(gpt, mode="reduce-overhead", fullgraph=True)
-    return gpt, mel_pos_emb, text_pos_emb, None, None
-
-
 class GPT(nn.Module):
     def __init__(
         self,
@@ -127,13 +81,13 @@ class GPT(nn.Module):
             self.mel_layer_pos_embedding,
             self.text_layer_pos_embedding,
         ) = build_hf_gpt_transformer(
-            layers,
-            model_dim,
-            heads,
-            self.max_mel_tokens,
-            self.max_text_tokens,
-            self.max_prompt_tokens,
-            checkpointing,
+            layers=layers,
+            model_dim=model_dim,
+            heads=heads,
+            max_mel_seq_len=self.max_mel_tokens,
+            max_text_seq_len=self.max_text_tokens,
+            max_prompt_len=self.max_prompt_tokens,
+            checkpointing=checkpointing,
         )
         if train_solo_embeddings:
             self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)
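One detail the keyword-argument change guards against (an observation on the diff, not stated in the commit message): the deleted local copy took max_prompt_len before checkpointing, while the shared tortoise helper takes checkpointing before the new max_prompt_len=0 keyword. A toy sketch with stub functions and made-up values shows why a purely positional call would go wrong:

# Stubs that only mirror the two signatures shown in the diff above.
def removed_local_copy(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, max_prompt_len, checkpointing):
    return ("old order", max_prompt_len, checkpointing)


def shared_tortoise_helper(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing, max_prompt_len=0):
    return ("new order", max_prompt_len, checkpointing)


# How gpt.py called its local copy before this commit (positionally):
print(removed_local_copy(8, 512, 8, 605, 402, 70, False))
# -> ('old order', 70, False)

# The same positional call against the new helper silently swaps the last two values:
print(shared_tortoise_helper(8, 512, 8, 605, 402, 70, False))
# -> ('new order', False, 70): max_prompt_len became False, checkpointing became 70.

# Keyword arguments, as the diff uses, keep the mapping explicit:
print(shared_tortoise_helper(layers=8, model_dim=512, heads=8,
                             max_mel_seq_len=605, max_text_seq_len=402,
                             max_prompt_len=70, checkpointing=False))
# -> ('new order', 70, False)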