mirror of https://github.com/coqui-ai/TTS.git
Implement most similar ref training approach
This commit is contained in:
parent
38f6f8f0bb
commit
077a849b3b
|
@ -88,6 +88,7 @@ class XTTSDataset(torch.utils.data.Dataset):
|
||||||
self.sample_rate = sample_rate
|
self.sample_rate = sample_rate
|
||||||
self.max_wav_len = model_args.max_wav_length
|
self.max_wav_len = model_args.max_wav_length
|
||||||
self.max_text_len = model_args.max_text_length
|
self.max_text_len = model_args.max_text_length
|
||||||
|
self.use_masking_gt_as_prompt = model_args.use_masking_gt_as_prompt
|
||||||
assert self.max_wav_len is not None and self.max_text_len is not None
|
assert self.max_wav_len is not None and self.max_text_len is not None
|
||||||
|
|
||||||
self.samples = samples
|
self.samples = samples
|
||||||
|
@ -109,7 +110,7 @@ class XTTSDataset(torch.utils.data.Dataset):
|
||||||
try:
|
try:
|
||||||
tseq, _, wav, _, _, _ = self.load_item(sample)
|
tseq, _, wav, _, _, _ = self.load_item(sample)
|
||||||
except:
|
except:
|
||||||
pass
|
continue
|
||||||
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
|
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
|
||||||
if (
|
if (
|
||||||
wav is None
|
wav is None
|
||||||
|
@ -140,10 +141,18 @@ class XTTSDataset(torch.utils.data.Dataset):
|
||||||
# Ultra short clips are also useless (and can cause problems within some models).
|
# Ultra short clips are also useless (and can cause problems within some models).
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
# get a slice from GT to condition the model
|
if self.use_masking_gt_as_prompt:
|
||||||
cond, cond_len, cond_idxs = get_prompt_slice(
|
# get a slice from GT to condition the model
|
||||||
audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
|
cond, cond_len, cond_idxs = get_prompt_slice(
|
||||||
)
|
audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ref_sample = sample["reference_path"] if "reference_path" in sample and sample["reference_path"] is not None else audiopath
|
||||||
|
cond, cond_len, cond_idxs = get_prompt_slice(
|
||||||
|
ref_sample, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
|
||||||
|
)
|
||||||
|
cond_idxs = torch.nan
|
||||||
|
cond_len = torch.nan
|
||||||
|
|
||||||
return tseq, audiopath, wav, cond, cond_len, cond_idxs
|
return tseq, audiopath, wav, cond, cond_len, cond_idxs
|
||||||
|
|
||||||
|
@ -199,8 +208,8 @@ class XTTSDataset(torch.utils.data.Dataset):
|
||||||
"wav_lengths": torch.tensor(wav.shape[-1], dtype=torch.long),
|
"wav_lengths": torch.tensor(wav.shape[-1], dtype=torch.long),
|
||||||
"filenames": audiopath,
|
"filenames": audiopath,
|
||||||
"conditioning": cond.unsqueeze(1),
|
"conditioning": cond.unsqueeze(1),
|
||||||
"cond_lens": torch.tensor(cond_len, dtype=torch.long),
|
"cond_lens": torch.tensor(cond_len, dtype=torch.long) if cond_len is not torch.nan else torch.tensor([cond_len]),
|
||||||
"cond_idxs": torch.tensor(cond_idxs),
|
"cond_idxs": torch.tensor(cond_idxs) if cond_idxs is not torch.nan else torch.tensor([cond_len]),
|
||||||
}
|
}
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
@ -221,6 +230,10 @@ class XTTSDataset(torch.utils.data.Dataset):
|
||||||
batch["conditioning"] = torch.stack(batch["conditioning"])
|
batch["conditioning"] = torch.stack(batch["conditioning"])
|
||||||
batch["cond_lens"] = torch.stack(batch["cond_lens"])
|
batch["cond_lens"] = torch.stack(batch["cond_lens"])
|
||||||
batch["cond_idxs"] = torch.stack(batch["cond_idxs"])
|
batch["cond_idxs"] = torch.stack(batch["cond_idxs"])
|
||||||
|
if torch.any(batch["cond_idxs"].isnan()):
|
||||||
|
batch["cond_lens"] = None
|
||||||
|
batch["cond_idxs"] = None
|
||||||
|
|
||||||
max_text_len = batch["text_lengths"].max()
|
max_text_len = batch["text_lengths"].max()
|
||||||
max_wav_len = batch["wav_lengths"].max()
|
max_wav_len = batch["wav_lengths"].max()
|
||||||
|
|
||||||
|
|
|
@ -52,6 +52,7 @@ class GPTArgs(XttsArgs):
|
||||||
xtts_checkpoint: str = ""
|
xtts_checkpoint: str = ""
|
||||||
gpt_checkpoint: str = "" # if defined it will replace the gpt weights on xtts model
|
gpt_checkpoint: str = "" # if defined it will replace the gpt weights on xtts model
|
||||||
vocoder: str = "" # overide vocoder key on the config to avoid json write issues
|
vocoder: str = "" # overide vocoder key on the config to avoid json write issues
|
||||||
|
use_masking_gt_as_prompt: bool = True
|
||||||
|
|
||||||
|
|
||||||
def callback_clearml_load_save(operation_type, model_info):
|
def callback_clearml_load_save(operation_type, model_info):
|
||||||
|
|
Loading…
Reference in New Issue