mirror of https://github.com/coqui-ai/TTS.git
refactor(xtts): use build_hf_gpt_transformer from tortoise
commit 33ac0d6ee1 (parent 490c973371)
TTS/tts/layers/tortoise/autoregressive.py
@@ -217,7 +217,15 @@ class LearnedPositionEmbeddings(nn.Module):
         return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
 
 
-def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
+def build_hf_gpt_transformer(
+    layers: int,
+    model_dim: int,
+    heads: int,
+    max_mel_seq_len: int,
+    max_text_seq_len: int,
+    checkpointing: bool,
+    max_prompt_len: int = 0,
+):
     """
     GPT-2 implemented by the HuggingFace library.
     """
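The refactored builder takes keyword arguments and gives max_prompt_len a default of 0, so existing tortoise call sites keep their previous context size while XTTS can reserve extra positions for conditioning-prompt tokens. A minimal sketch of a call, assuming the TTS package is installed; the sizes are placeholders, not the shipped defaults:

    from TTS.tts.layers.tortoise.autoregressive import build_hf_gpt_transformer

    # Placeholder sizes chosen to keep the sketch light; real configs are much larger.
    gpt, mel_pos_emb, text_pos_emb, _, _ = build_hf_gpt_transformer(
        layers=2,
        model_dim=64,
        heads=4,
        max_mel_seq_len=100,
        max_text_seq_len=50,
        checkpointing=False,  # max_prompt_len is omitted here and defaults to 0
    )

The five-value return keeps the tortoise unpacking pattern; the last two slots stay None, as the hunks below show.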
@@ -225,8 +233,8 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
 
     gpt_config = GPT2Config(
         vocab_size=256,  # Unused.
-        n_positions=max_mel_seq_len + max_text_seq_len,
-        n_ctx=max_mel_seq_len + max_text_seq_len,
+        n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len,
+        n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len,
         n_embd=model_dim,
         n_layer=layers,
         n_head=heads,
@@ -239,13 +247,18 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
     gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
     # Built-in token embeddings are unused.
     del gpt.wte
-    return (
-        gpt,
-        LearnedPositionEmbeddings(max_mel_seq_len, model_dim),
-        LearnedPositionEmbeddings(max_text_seq_len, model_dim),
-        None,
-        None,
-    )
+
+    mel_pos_emb = (
+        LearnedPositionEmbeddings(max_mel_seq_len, model_dim)
+        if max_mel_seq_len != -1
+        else functools.partial(null_position_embeddings, dim=model_dim)
+    )
+    text_pos_emb = (
+        LearnedPositionEmbeddings(max_text_seq_len, model_dim)
+        if max_mel_seq_len != -1
+        else functools.partial(null_position_embeddings, dim=model_dim)
+    )
+    return gpt, mel_pos_emb, text_pos_emb, None, None
 
 
 class MelEncoder(nn.Module):
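For reference, null_position_embeddings (defined in autoregressive.py and duplicated in the XTTS copy removed further down) simply returns zeros shaped like position embeddings, so passing -1 as a sequence length disables learned positions for that stream. A small sketch of that fallback, with placeholder shapes:

    import torch

    def null_position_embeddings(range, dim):
        # Same body as the helper used above: position "embeddings" that add nothing.
        return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)

    tokens = torch.zeros(2, 7, dtype=torch.long)      # (batch, sequence) placeholder
    zeros = null_position_embeddings(tokens, dim=16)  # -> (2, 7, 16), all zeros
    assert zeros.abs().sum() == 0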
@@ -339,12 +352,12 @@ class UnifiedVoice(nn.Module):
             self.mel_layer_pos_embedding,
             self.text_layer_pos_embedding,
         ) = build_hf_gpt_transformer(
-            layers,
-            model_dim,
-            heads,
-            self.max_mel_tokens + 2 + self.max_conditioning_inputs,
-            self.max_text_tokens + 2,
-            checkpointing,
+            layers=layers,
+            model_dim=model_dim,
+            heads=heads,
+            max_mel_seq_len=self.max_mel_tokens + 2 + self.max_conditioning_inputs,
+            max_text_seq_len=self.max_text_tokens + 2,
+            checkpointing=checkpointing,
         )
         if train_solo_embeddings:
             self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)
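The tortoise call site keeps its original budgets, now spelled out as keywords. A sketch of the arithmetic behind them; the +2 is assumed to cover start/stop tokens, which this commit does not state:

    # Placeholders standing in for the UnifiedVoice constructor arguments.
    max_mel_tokens, max_text_tokens, max_conditioning_inputs = 100, 50, 1

    # Assumption: +2 reserves start/stop tokens; conditioning inputs take extra
    # positions in the mel stream, matching the expressions in the hunk above.
    max_mel_seq_len = max_mel_tokens + 2 + max_conditioning_inputs  # 103
    max_text_seq_len = max_text_tokens + 2                          # 52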
TTS/tts/layers/xtts/gpt.py
@@ -1,6 +1,5 @@
 # ported from: https://github.com/neonbjb/tortoise-tts
 
-import functools
 import random
 
 import torch
@@ -8,61 +7,16 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers import GPT2Config
 
-from TTS.tts.layers.tortoise.autoregressive import LearnedPositionEmbeddings, _prepare_attention_mask_for_generation
+from TTS.tts.layers.tortoise.autoregressive import (
+    LearnedPositionEmbeddings,
+    _prepare_attention_mask_for_generation,
+    build_hf_gpt_transformer,
+)
 from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
 from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
 from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler
-
-
-def null_position_embeddings(range, dim):
-    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
-
-
-def build_hf_gpt_transformer(
-    layers,
-    model_dim,
-    heads,
-    max_mel_seq_len,
-    max_text_seq_len,
-    max_prompt_len,
-    checkpointing,
-):
-    """
-    GPT-2 implemented by the HuggingFace library.
-    """
-    from transformers import GPT2Config, GPT2Model
-
-    gpt_config = GPT2Config(
-        vocab_size=256,  # Unused.
-        n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len,
-        n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len,
-        n_embd=model_dim,
-        n_layer=layers,
-        n_head=heads,
-        gradient_checkpointing=checkpointing,
-        use_cache=not checkpointing,
-    )
-    gpt = GPT2Model(gpt_config)
-    # Override the built in positional embeddings
-    del gpt.wpe
-    gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
-    # Built-in token embeddings are unused.
-    del gpt.wte
-
-    mel_pos_emb = (
-        LearnedPositionEmbeddings(max_mel_seq_len, model_dim)
-        if max_mel_seq_len != -1
-        else functools.partial(null_position_embeddings, dim=model_dim)
-    )
-    text_pos_emb = (
-        LearnedPositionEmbeddings(max_text_seq_len, model_dim)
-        if max_mel_seq_len != -1
-        else functools.partial(null_position_embeddings, dim=model_dim)
-    )
-    # gpt = torch.compile(gpt, mode="reduce-overhead", fullgraph=True)
-    return gpt, mel_pos_emb, text_pos_emb, None, None
 
 
 class GPT(nn.Module):
     def __init__(
         self,
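The block deleted above duplicates the builder that now lives only in tortoise's autoregressive.py. One detail worth noting from it is how the checkpointing flag is wired into the HuggingFace config: gradient checkpointing on, KV cache off. A hedged sketch of that wiring with placeholder sizes:

    from transformers import GPT2Config, GPT2Model

    checkpointing = True
    config = GPT2Config(
        vocab_size=256,   # unused: the callers replace the token embeddings
        n_positions=160,  # placeholder: mel + text + prompt budgets summed
        n_ctx=160,
        n_embd=64,        # placeholder sizes to keep the sketch light
        n_layer=2,
        n_head=4,
        gradient_checkpointing=checkpointing,
        use_cache=not checkpointing,  # KV caching is disabled while checkpointing
    )
    gpt = GPT2Model(config)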
@@ -127,13 +81,13 @@ class GPT(nn.Module):
             self.mel_layer_pos_embedding,
             self.text_layer_pos_embedding,
         ) = build_hf_gpt_transformer(
-            layers,
-            model_dim,
-            heads,
-            self.max_mel_tokens,
-            self.max_text_tokens,
-            self.max_prompt_tokens,
-            checkpointing,
+            layers=layers,
+            model_dim=model_dim,
+            heads=heads,
+            max_mel_seq_len=self.max_mel_tokens,
+            max_text_seq_len=self.max_text_tokens,
+            max_prompt_len=self.max_prompt_tokens,
+            checkpointing=checkpointing,
         )
         if train_solo_embeddings:
             self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)