From c182535e2a2957c6c819c52da3a1f2d6d372a842 Mon Sep 17 00:00:00 2001 From: Eren Gölge Date: Sat, 4 Nov 2023 13:40:28 +0100 Subject: [PATCH] Make prompt embedding configurable --- TTS/tts/layers/xtts/gpt.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py index da5d8995..51a64b99 100644 --- a/TTS/tts/layers/xtts/gpt.py +++ b/TTS/tts/layers/xtts/gpt.py @@ -140,9 +140,6 @@ class GPT(nn.Module): self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim) self.mel_embedding = nn.Embedding(self.num_audio_tokens, model_dim) - self.prompt_embedding = nn.Embedding(self.num_audio_tokens, model_dim) - self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, model_dim) - ( self.gpt, self.mel_pos_embedding, @@ -170,6 +167,7 @@ class GPT(nn.Module): self.mel_head = nn.Linear(model_dim, self.num_audio_tokens) if self.use_perceiver_resampler: + # XTTS v2 self.conditioning_perceiver = PerceiverResampler( dim=model_dim, depth=2, @@ -180,6 +178,10 @@ class GPT(nn.Module): ff_mult=4, use_flash_attn=False, ) + else: + # XTTS v1 + self.prompt_embedding = nn.Embedding(self.num_audio_tokens, model_dim) + self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, model_dim) def get_grad_norm_parameter_groups(self): return {