diff --git a/TTS/tts/layers/tortoise/autoregressive.py b/TTS/tts/layers/tortoise/autoregressive.py index 3463e63b..19c1adc0 100644 --- a/TTS/tts/layers/tortoise/autoregressive.py +++ b/TTS/tts/layers/tortoise/autoregressive.py @@ -217,7 +217,15 @@ class LearnedPositionEmbeddings(nn.Module): return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0) -def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing): +def build_hf_gpt_transformer( + layers: int, + model_dim: int, + heads: int, + max_mel_seq_len: int, + max_text_seq_len: int, + checkpointing: bool, + max_prompt_len: int = 0, +): """ GPT-2 implemented by the HuggingFace library. """ @@ -225,8 +233,8 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text gpt_config = GPT2Config( vocab_size=256, # Unused. - n_positions=max_mel_seq_len + max_text_seq_len, - n_ctx=max_mel_seq_len + max_text_seq_len, + n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len, + n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len, n_embd=model_dim, n_layer=layers, n_head=heads, @@ -239,13 +247,18 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) # Built-in token embeddings are unused. del gpt.wte - return ( - gpt, - LearnedPositionEmbeddings(max_mel_seq_len, model_dim), - LearnedPositionEmbeddings(max_text_seq_len, model_dim), - None, - None, + + mel_pos_emb = ( + LearnedPositionEmbeddings(max_mel_seq_len, model_dim) + if max_mel_seq_len != -1 + else functools.partial(null_position_embeddings, dim=model_dim) ) + text_pos_emb = ( + LearnedPositionEmbeddings(max_text_seq_len, model_dim) + if max_mel_seq_len != -1 + else functools.partial(null_position_embeddings, dim=model_dim) + ) + return gpt, mel_pos_emb, text_pos_emb, None, None class MelEncoder(nn.Module): @@ -339,12 +352,12 @@ class UnifiedVoice(nn.Module): self.mel_layer_pos_embedding, self.text_layer_pos_embedding, ) = build_hf_gpt_transformer( - layers, - model_dim, - heads, - self.max_mel_tokens + 2 + self.max_conditioning_inputs, - self.max_text_tokens + 2, - checkpointing, + layers=layers, + model_dim=model_dim, + heads=heads, + max_mel_seq_len=self.max_mel_tokens + 2 + self.max_conditioning_inputs, + max_text_seq_len=self.max_text_tokens + 2, + checkpointing=checkpointing, ) if train_solo_embeddings: self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True) diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py index f9328761..899522e0 100644 --- a/TTS/tts/layers/xtts/gpt.py +++ b/TTS/tts/layers/xtts/gpt.py @@ -1,6 +1,5 @@ # ported from: https://github.com/neonbjb/tortoise-tts -import functools import random import torch @@ -8,61 +7,16 @@ import torch.nn as nn import torch.nn.functional as F from transformers import GPT2Config -from TTS.tts.layers.tortoise.autoregressive import LearnedPositionEmbeddings, _prepare_attention_mask_for_generation +from TTS.tts.layers.tortoise.autoregressive import ( + LearnedPositionEmbeddings, + _prepare_attention_mask_for_generation, + build_hf_gpt_transformer, +) from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler -def null_position_embeddings(range, dim): - return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device) - - -def build_hf_gpt_transformer( - layers, - model_dim, - heads, - max_mel_seq_len, - max_text_seq_len, - max_prompt_len, - checkpointing, -): - """ - GPT-2 implemented by the HuggingFace library. - """ - from transformers import GPT2Config, GPT2Model - - gpt_config = GPT2Config( - vocab_size=256, # Unused. - n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len, - n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len, - n_embd=model_dim, - n_layer=layers, - n_head=heads, - gradient_checkpointing=checkpointing, - use_cache=not checkpointing, - ) - gpt = GPT2Model(gpt_config) - # Override the built in positional embeddings - del gpt.wpe - gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) - # Built-in token embeddings are unused. - del gpt.wte - - mel_pos_emb = ( - LearnedPositionEmbeddings(max_mel_seq_len, model_dim) - if max_mel_seq_len != -1 - else functools.partial(null_position_embeddings, dim=model_dim) - ) - text_pos_emb = ( - LearnedPositionEmbeddings(max_text_seq_len, model_dim) - if max_mel_seq_len != -1 - else functools.partial(null_position_embeddings, dim=model_dim) - ) - # gpt = torch.compile(gpt, mode="reduce-overhead", fullgraph=True) - return gpt, mel_pos_emb, text_pos_emb, None, None - - class GPT(nn.Module): def __init__( self, @@ -127,13 +81,13 @@ class GPT(nn.Module): self.mel_layer_pos_embedding, self.text_layer_pos_embedding, ) = build_hf_gpt_transformer( - layers, - model_dim, - heads, - self.max_mel_tokens, - self.max_text_tokens, - self.max_prompt_tokens, - checkpointing, + layers=layers, + model_dim=model_dim, + heads=heads, + max_mel_seq_len=self.max_mel_tokens, + max_text_seq_len=self.max_text_tokens, + max_prompt_len=self.max_prompt_tokens, + checkpointing=checkpointing, ) if train_solo_embeddings: self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)