Bug fix masking with XTTS perceiver

This commit is contained in:
Edresson Casanova 2023-11-01 14:07:39 -03:00 committed by Eren Gölge
parent 5df8f76b0c
commit a032d9877b
1 changed file with 6 additions and 3 deletions

View File

@@ -403,8 +403,11 @@ class GPT(nn.Module):
if cond_idxs is not None:
# recompute cond idxs for mel lengths
for idx, l in enumerate(code_lengths):
cond_idxs[idx] = cond_idxs[idx] / self.code_stride_len
for idx in range(cond_idxs.size(0)):
if self.use_perceiver_resampler:
cond_idxs[idx] = cond_idxs[idx] // self.perceiver_cond_length_compression
else:
cond_idxs[idx] = cond_idxs[idx] // self.code_stride_len
# If len(codes) + 3 is larger than the maximum allowed length, we truncate the codes.
max_mel_len = code_lengths.max()
@@ -512,7 +515,7 @@ class GPT(nn.Module):
prompt=cond_latents,
get_attns=return_attentions,
return_latent=return_latent,
attn_mask_cond=attn_mask_cond if not self.use_perceiver_resampler else None,
attn_mask_cond=attn_mask_cond,
attn_mask_text=attn_mask_text,
attn_mask_mel=attn_mask_mel,
)