Bug fix on gpt forward

Edresson Casanova 2023-11-01 14:22:58 -03:00 committed by Eren Gölge
parent a032d9877b
commit 32796fdfc1
1 changed file with 5 additions and 3 deletions


@@ -409,6 +409,11 @@ class GPT(nn.Module):
                 else:
                     cond_idxs[idx] = cond_idxs[idx] // self.code_stride_len
 
+        # ensure that the cond_mel does not have padding
+        if cond_lens is not None and cond_idxs is None:
+            min_cond_len = torch.min(cond_lens)
+            cond_mels = cond_mels[:, :, :, :min_cond_len]
+
         # If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes.
         max_mel_len = code_lengths.max()
@@ -497,9 +502,6 @@ class GPT(nn.Module):
         # Compute speech conditioning input
         if cond_latents is None:
-            if cond_lens is not None and cond_idxs is None:
-                min_cond_len = torch.min(cond_lens)
-                cond_mels = cond_mels[:, :, :, :min_cond_len]
             cond_latents = self.get_style_emb(cond_mels).transpose(1, 2)
 
         # Get logits
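
Net effect of the change: the crop that trims cond_mels to the shortest unpadded conditioning length now runs unconditionally near the top of forward, before the conditioning latents are computed, rather than only inside the cond_latents is None branch. Below is a minimal sketch of that cropping step, not the repo's code; the 4-D [batch, n_refs, n_mels, frames] layout and the helper name trim_cond_mels are assumptions made for illustration.

import torch

def trim_cond_mels(cond_mels: torch.Tensor, cond_lens: torch.Tensor) -> torch.Tensor:
    # Crop every sample to the shortest true conditioning length so that no
    # padded frames reach get_style_emb (same slicing as in the diff above).
    min_cond_len = torch.min(cond_lens)
    return cond_mels[:, :, :, :min_cond_len]

# Hypothetical batch: 2 samples, 1 reference mel each, 80 mel bins, padded to 120 frames.
cond_mels = torch.randn(2, 1, 80, 120)
cond_lens = torch.tensor([120, 96])  # true (unpadded) lengths per sample
print(trim_cond_mels(cond_mels, cond_lens).shape)  # torch.Size([2, 1, 80, 96])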