Make prompt embedding configurable

This commit is contained in:
Eren Gölge 2023-11-04 13:40:28 +01:00
parent aa16da9194
commit c182535e2a
1 changed file with 5 additions and 3 deletions

View File

@ -140,9 +140,6 @@ class GPT(nn.Module):
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
self.mel_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
self.prompt_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, model_dim)
(
self.gpt,
self.mel_pos_embedding,
@ -170,6 +167,7 @@ class GPT(nn.Module):
self.mel_head = nn.Linear(model_dim, self.num_audio_tokens)
if self.use_perceiver_resampler:
# XTTS v2
self.conditioning_perceiver = PerceiverResampler(
dim=model_dim,
depth=2,
@ -180,6 +178,10 @@ class GPT(nn.Module):
ff_mult=4,
use_flash_attn=False,
)
else:
# XTTS v1
self.prompt_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, model_dim)
def get_grad_norm_parameter_groups(self):
return {