diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py index c3477d52..0fe056a2 100644 --- a/TTS/tts/layers/xtts/gpt.py +++ b/TTS/tts/layers/xtts/gpt.py @@ -403,8 +403,11 @@ class GPT(nn.Module): if cond_idxs is not None: # recompute cond idxs for mel lengths - for idx, l in enumerate(code_lengths): - cond_idxs[idx] = cond_idxs[idx] / self.code_stride_len + for idx in range(cond_idxs.size(0)): + if self.use_perceiver_resampler: + cond_idxs[idx] = cond_idxs[idx] // self.perceiver_cond_length_compression + else: + cond_idxs[idx] = cond_idxs[idx] // self.code_stride_len # If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes. max_mel_len = code_lengths.max() @@ -512,7 +515,7 @@ class GPT(nn.Module): prompt=cond_latents, get_attns=return_attentions, return_latent=return_latent, - attn_mask_cond=attn_mask_cond if not self.use_perceiver_resampler else None, + attn_mask_cond=attn_mask_cond, attn_mask_text=attn_mask_text, attn_mask_mel=attn_mask_mel, )