diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py index e7c0a41a..dfd7774e 100644 --- a/TTS/tts/layers/xtts/gpt.py +++ b/TTS/tts/layers/xtts/gpt.py @@ -450,7 +450,7 @@ class GPT(nn.Module): ) if cond_idxs is not None: - for idx, r in enumerate(cond_idxs.squeeze()): + for idx, r in enumerate(cond_idxs): l = r[1] - r[0] attn_mask_cond[idx, l:] = 0.0