mirror of https://github.com/coqui-ai/TTS.git
Make prompt embedding configurable
This commit is contained in:
parent
aa16da9194
commit
c182535e2a
|
@ -140,9 +140,6 @@ class GPT(nn.Module):
|
||||||
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
|
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
|
||||||
self.mel_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
|
self.mel_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
|
||||||
|
|
||||||
self.prompt_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
|
|
||||||
self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, model_dim)
|
|
||||||
|
|
||||||
(
|
(
|
||||||
self.gpt,
|
self.gpt,
|
||||||
self.mel_pos_embedding,
|
self.mel_pos_embedding,
|
||||||
|
@ -170,6 +167,7 @@ class GPT(nn.Module):
|
||||||
self.mel_head = nn.Linear(model_dim, self.num_audio_tokens)
|
self.mel_head = nn.Linear(model_dim, self.num_audio_tokens)
|
||||||
|
|
||||||
if self.use_perceiver_resampler:
|
if self.use_perceiver_resampler:
|
||||||
|
# XTTS v2
|
||||||
self.conditioning_perceiver = PerceiverResampler(
|
self.conditioning_perceiver = PerceiverResampler(
|
||||||
dim=model_dim,
|
dim=model_dim,
|
||||||
depth=2,
|
depth=2,
|
||||||
|
@ -180,6 +178,10 @@ class GPT(nn.Module):
|
||||||
ff_mult=4,
|
ff_mult=4,
|
||||||
use_flash_attn=False,
|
use_flash_attn=False,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
# XTTS v1
|
||||||
|
self.prompt_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
|
||||||
|
self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, model_dim)
|
||||||
|
|
||||||
def get_grad_norm_parameter_groups(self):
|
def get_grad_norm_parameter_groups(self):
|
||||||
return {
|
return {
|
||||||
|
|
Loading…
Reference in New Issue