diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py index 0fe056a2..6eee8481 100644 --- a/TTS/tts/layers/xtts/gpt.py +++ b/TTS/tts/layers/xtts/gpt.py @@ -409,6 +409,11 @@ class GPT(nn.Module): else: cond_idxs[idx] = cond_idxs[idx] // self.code_stride_len + # ensure that the cond_mel does not have padding + if cond_lens is not None and cond_idxs is None: + min_cond_len = torch.min(cond_lens) + cond_mels = cond_mels[:, :, :, :min_cond_len] + # If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes. max_mel_len = code_lengths.max() @@ -497,9 +502,6 @@ class GPT(nn.Module): # Compute speech conditioning input if cond_latents is None: - if cond_lens is not None and cond_idxs is None: - min_cond_len = torch.min(cond_lens) - cond_mels = cond_mels[:, :, :, :min_cond_len] cond_latents = self.get_style_emb(cond_mels).transpose(1, 2) # Get logits