Bug Fix on XTTS v2.0 training

This commit is contained in:
Edresson Casanova 2023-11-03 13:26:13 -03:00 committed by Eren G??lge
parent 32796fdfc1
commit cff8542012
2 changed files with 5 additions and 5 deletions

View File

@ -410,9 +410,9 @@ class GPT(nn.Module):
cond_idxs[idx] = cond_idxs[idx] // self.code_stride_len
# ensure that the cond_mel does not have padding
if cond_lens is not None and cond_idxs is None:
min_cond_len = torch.min(cond_lens)
cond_mels = cond_mels[:, :, :, :min_cond_len]
# if cond_lens is not None and cond_idxs is None:
# min_cond_len = torch.min(cond_lens)
# cond_mels = cond_mels[:, :, :, :min_cond_len]
# If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes.
max_mel_len = code_lengths.max()

View File

@ -211,7 +211,7 @@ class XTTSDataset(torch.utils.data.Dataset):
"filenames": audiopath,
"conditioning": cond.unsqueeze(1),
"cond_lens": torch.tensor(cond_len, dtype=torch.long) if cond_len is not torch.nan else torch.tensor([cond_len]),
"cond_idxs": torch.tensor(cond_idxs) if cond_idxs is not torch.nan else torch.tensor([cond_len]),
"cond_idxs": torch.tensor(cond_idxs) if cond_idxs is not torch.nan else torch.tensor([cond_idxs]),
}
return res
@ -234,7 +234,7 @@ class XTTSDataset(torch.utils.data.Dataset):
batch["cond_idxs"] = torch.stack(batch["cond_idxs"])
if torch.any(batch["cond_idxs"].isnan()):
batch["cond_lens"] = None
batch["cond_idxs"] = None
if torch.any(batch["cond_lens"].isnan()):
batch["cond_lens"] = None