From a032d9877b3f08eaaac696565d92a549bbd5a259 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 1 Nov 2023 14:07:39 -0300 Subject: [PATCH] Bug fix masking with XTTS perceiver --- TTS/tts/layers/xtts/gpt.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py index c3477d52..0fe056a2 100644 --- a/TTS/tts/layers/xtts/gpt.py +++ b/TTS/tts/layers/xtts/gpt.py @@ -403,8 +403,11 @@ class GPT(nn.Module): if cond_idxs is not None: # recompute cond idxs for mel lengths - for idx, l in enumerate(code_lengths): - cond_idxs[idx] = cond_idxs[idx] / self.code_stride_len + for idx in range(cond_idxs.size(0)): + if self.use_perceiver_resampler: + cond_idxs[idx] = cond_idxs[idx] // self.perceiver_cond_length_compression + else: + cond_idxs[idx] = cond_idxs[idx] // self.code_stride_len # If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes. max_mel_len = code_lengths.max() @@ -512,7 +515,7 @@ class GPT(nn.Module): prompt=cond_latents, get_attns=return_attentions, return_latent=return_latent, - attn_mask_cond=attn_mask_cond if not self.use_perceiver_resampler else None, + attn_mask_cond=attn_mask_cond, attn_mask_text=attn_mask_text, attn_mask_mel=attn_mask_mel, )