mirror of https://github.com/coqui-ai/TTS.git
Bug fix masking with XTTS perceiver
This commit is contained in:
parent
5df8f76b0c
commit
a032d9877b
|
@ -403,8 +403,11 @@ class GPT(nn.Module):
|
|||
|
||||
if cond_idxs is not None:
|
||||
# recompute cond idxs for mel lengths
|
||||
for idx, l in enumerate(code_lengths):
|
||||
cond_idxs[idx] = cond_idxs[idx] / self.code_stride_len
|
||||
for idx in range(cond_idxs.size(0)):
|
||||
if self.use_perceiver_resampler:
|
||||
cond_idxs[idx] = cond_idxs[idx] // self.perceiver_cond_length_compression
|
||||
else:
|
||||
cond_idxs[idx] = cond_idxs[idx] // self.code_stride_len
|
||||
|
||||
# If len(codes) + 3 is larger than maximum allowed length, we truncate the codes.
|
||||
max_mel_len = code_lengths.max()
|
||||
|
@ -512,7 +515,7 @@ class GPT(nn.Module):
|
|||
prompt=cond_latents,
|
||||
get_attns=return_attentions,
|
||||
return_latent=return_latent,
|
||||
attn_mask_cond=attn_mask_cond if not self.use_perceiver_resampler else None,
|
||||
attn_mask_cond=attn_mask_cond,
|
||||
attn_mask_text=attn_mask_text,
|
||||
attn_mask_mel=attn_mask_mel,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue