From cff8542012fba4bf748c864313c034484e9cf039 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 3 Nov 2023 13:26:13 -0300 Subject: [PATCH] Bug Fix on XTTS v2.0 training --- TTS/tts/layers/xtts/gpt.py | 6 +++--- TTS/tts/layers/xtts/trainer/dataset.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py index 6eee8481..da5d8995 100644 --- a/TTS/tts/layers/xtts/gpt.py +++ b/TTS/tts/layers/xtts/gpt.py @@ -410,9 +410,9 @@ class GPT(nn.Module): cond_idxs[idx] = cond_idxs[idx] // self.code_stride_len # ensure that the cond_mel does not have padding - if cond_lens is not None and cond_idxs is None: - min_cond_len = torch.min(cond_lens) - cond_mels = cond_mels[:, :, :, :min_cond_len] + # if cond_lens is not None and cond_idxs is None: + # min_cond_len = torch.min(cond_lens) + # cond_mels = cond_mels[:, :, :, :min_cond_len] # If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes. max_mel_len = code_lengths.max() diff --git a/TTS/tts/layers/xtts/trainer/dataset.py b/TTS/tts/layers/xtts/trainer/dataset.py index bef930d1..cde43287 100644 --- a/TTS/tts/layers/xtts/trainer/dataset.py +++ b/TTS/tts/layers/xtts/trainer/dataset.py @@ -211,7 +211,7 @@ class XTTSDataset(torch.utils.data.Dataset): "filenames": audiopath, "conditioning": cond.unsqueeze(1), "cond_lens": torch.tensor(cond_len, dtype=torch.long) if cond_len is not torch.nan else torch.tensor([cond_len]), - "cond_idxs": torch.tensor(cond_idxs) if cond_idxs is not torch.nan else torch.tensor([cond_len]), + "cond_idxs": torch.tensor(cond_idxs) if cond_idxs is not torch.nan else torch.tensor([cond_idxs]), } return res @@ -234,7 +234,7 @@ class XTTSDataset(torch.utils.data.Dataset): batch["cond_idxs"] = torch.stack(batch["cond_idxs"]) if torch.any(batch["cond_idxs"].isnan()): - batch["cond_lens"] = None + batch["cond_idxs"] = None if torch.any(batch["cond_lens"].isnan()): batch["cond_lens"] = None