mirror of https://github.com/coqui-ai/TTS.git
Bug Fix on XTTS v2.0 training
This commit is contained in:
parent
32796fdfc1
commit
cff8542012
|
@ -410,9 +410,9 @@ class GPT(nn.Module):
|
|||
cond_idxs[idx] = cond_idxs[idx] // self.code_stride_len
|
||||
|
||||
# ensure that the cond_mel does not have padding
|
||||
if cond_lens is not None and cond_idxs is None:
|
||||
min_cond_len = torch.min(cond_lens)
|
||||
cond_mels = cond_mels[:, :, :, :min_cond_len]
|
||||
# if cond_lens is not None and cond_idxs is None:
|
||||
# min_cond_len = torch.min(cond_lens)
|
||||
# cond_mels = cond_mels[:, :, :, :min_cond_len]
|
||||
|
||||
# If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes.
|
||||
max_mel_len = code_lengths.max()
|
||||
|
|
|
@ -211,7 +211,7 @@ class XTTSDataset(torch.utils.data.Dataset):
|
|||
"filenames": audiopath,
|
||||
"conditioning": cond.unsqueeze(1),
|
||||
"cond_lens": torch.tensor(cond_len, dtype=torch.long) if cond_len is not torch.nan else torch.tensor([cond_len]),
|
||||
"cond_idxs": torch.tensor(cond_idxs) if cond_idxs is not torch.nan else torch.tensor([cond_len]),
|
||||
"cond_idxs": torch.tensor(cond_idxs) if cond_idxs is not torch.nan else torch.tensor([cond_idxs]),
|
||||
}
|
||||
return res
|
||||
|
||||
|
@ -234,7 +234,7 @@ class XTTSDataset(torch.utils.data.Dataset):
|
|||
batch["cond_idxs"] = torch.stack(batch["cond_idxs"])
|
||||
|
||||
if torch.any(batch["cond_idxs"].isnan()):
|
||||
batch["cond_lens"] = None
|
||||
batch["cond_idxs"] = None
|
||||
|
||||
if torch.any(batch["cond_lens"].isnan()):
|
||||
batch["cond_lens"] = None
|
||||
|
|
Loading…
Reference in New Issue