mirror of https://github.com/coqui-ai/TTS.git
Fix masking bug with the XTTS perceiver
This commit is contained in:
parent
5df8f76b0c
commit
a032d9877b
|
@@ -403,8 +403,11 @@ class GPT(nn.Module):
|
||||||
|
|
||||||
if cond_idxs is not None:
|
if cond_idxs is not None:
|
||||||
# recompute cond idxs for mel lengths
|
# recompute cond idxs for mel lengths
|
||||||
for idx, l in enumerate(code_lengths):
|
for idx in range(cond_idxs.size(0)):
|
||||||
cond_idxs[idx] = cond_idxs[idx] / self.code_stride_len
|
if self.use_perceiver_resampler:
|
||||||
|
cond_idxs[idx] = cond_idxs[idx] // self.perceiver_cond_length_compression
|
||||||
|
else:
|
||||||
|
cond_idxs[idx] = cond_idxs[idx] // self.code_stride_len
|
||||||
|
|
||||||
# If len(codes) + 3 is larger than maximum allowed length, we truncate the codes.
|
# If len(codes) + 3 is larger than maximum allowed length, we truncate the codes.
|
||||||
max_mel_len = code_lengths.max()
|
max_mel_len = code_lengths.max()
|
||||||
|
@@ -512,7 +515,7 @@ class GPT(nn.Module):
|
||||||
prompt=cond_latents,
|
prompt=cond_latents,
|
||||||
get_attns=return_attentions,
|
get_attns=return_attentions,
|
||||||
return_latent=return_latent,
|
return_latent=return_latent,
|
||||||
attn_mask_cond=attn_mask_cond if not self.use_perceiver_resampler else None,
|
attn_mask_cond=attn_mask_cond,
|
||||||
attn_mask_text=attn_mask_text,
|
attn_mask_text=attn_mask_text,
|
||||||
attn_mask_mel=attn_mask_mel,
|
attn_mask_mel=attn_mask_mel,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue