mirror of https://github.com/coqui-ai/TTS.git
Make prompt embedding configurable
parent aa16da9194
commit c182535e2a
@@ -140,9 +140,6 @@ class GPT(nn.Module):
         self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
         self.mel_embedding = nn.Embedding(self.num_audio_tokens, model_dim)

-        self.prompt_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
-        self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, model_dim)
-
         (
             self.gpt,
             self.mel_pos_embedding,
@@ -170,6 +167,7 @@ class GPT(nn.Module):
         self.mel_head = nn.Linear(model_dim, self.num_audio_tokens)

         if self.use_perceiver_resampler:
+            # XTTS v2
             self.conditioning_perceiver = PerceiverResampler(
                 dim=model_dim,
                 depth=2,
@@ -180,6 +178,10 @@ class GPT(nn.Module):
                 ff_mult=4,
                 use_flash_attn=False,
             )
+        else:
+            # XTTS v1
+            self.prompt_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
+            self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, model_dim)

     def get_grad_norm_parameter_groups(self):
         return {
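A minimal, self-contained sketch of the pattern this commit applies (not the actual coqui-ai/TTS GPT class): the XTTS v1 prompt layers are only registered when the perceiver-resampler (XTTS v2) path is disabled, so a model only carries the conditioning layers it actually uses. The ToyGPT class, its default sizes, and the nn.Identity / nn.Embedding stand-ins for PerceiverResampler and LearnedPositionEmbeddings are illustrative assumptions, not the project's code.

import torch.nn as nn

class ToyGPT(nn.Module):
    def __init__(self, num_audio_tokens=1026, model_dim=1024, use_perceiver_resampler=True):
        super().__init__()
        self.use_perceiver_resampler = use_perceiver_resampler
        if self.use_perceiver_resampler:
            # XTTS v2: conditioning goes through a resampler; nn.Identity()
            # stands in for the real PerceiverResampler here.
            self.conditioning_perceiver = nn.Identity()
        else:
            # XTTS v1: explicit prompt token embedding plus learned positions
            # (a plain nn.Embedding stands in for LearnedPositionEmbeddings).
            self.prompt_embedding = nn.Embedding(num_audio_tokens, model_dim)
            self.prompt_pos_embedding = nn.Embedding(24 * 9, model_dim)

v2 = ToyGPT(use_perceiver_resampler=True)
v1 = ToyGPT(use_perceiver_resampler=False)
print(hasattr(v2, "prompt_embedding"), hasattr(v1, "prompt_embedding"))  # -> False True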