From 077a849b3b9ba9b4ab18a45fe19922bd02979065 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Wed, 25 Oct 2023 18:50:35 -0300
Subject: [PATCH] Implement most similar ref training approach

---
 TTS/tts/layers/xtts/trainer/dataset.py     | 27 ++++++++++++++++------
 TTS/tts/layers/xtts/trainer/gpt_trainer.py |  1 +
 2 files changed, 21 insertions(+), 7 deletions(-)

diff --git a/TTS/tts/layers/xtts/trainer/dataset.py b/TTS/tts/layers/xtts/trainer/dataset.py
index 41401fd6..e37387ce 100644
--- a/TTS/tts/layers/xtts/trainer/dataset.py
+++ b/TTS/tts/layers/xtts/trainer/dataset.py
@@ -88,6 +88,7 @@ class XTTSDataset(torch.utils.data.Dataset):
         self.sample_rate = sample_rate
         self.max_wav_len = model_args.max_wav_length
         self.max_text_len = model_args.max_text_length
+        self.use_masking_gt_as_prompt = model_args.use_masking_gt_as_prompt
         assert self.max_wav_len is not None and self.max_text_len is not None
 
         self.samples = samples
@@ -109,7 +110,7 @@ class XTTSDataset(torch.utils.data.Dataset):
             try:
                 tseq, _, wav, _, _, _ = self.load_item(sample)
             except:
-                pass
+                continue
             # Basically, this audio file is nonexistent or too long to be supported by the dataset.
             if (
                 wav is None
@@ -140,10 +141,18 @@ class XTTSDataset(torch.utils.data.Dataset):
             # Ultra short clips are also useless (and can cause problems within some models).
             raise ValueError
 
-        # get a slice from GT to condition the model
-        cond, cond_len, cond_idxs = get_prompt_slice(
-            audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
-        )
+        if self.use_masking_gt_as_prompt:
+            # get a slice from GT to condition the model
+            cond, cond_len, cond_idxs = get_prompt_slice(
+                audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
+            )
+        else:
+            ref_sample = sample["reference_path"] if "reference_path" in sample and sample["reference_path"] is not None else audiopath
+            cond, cond_len, cond_idxs = get_prompt_slice(
+                ref_sample, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
+            )
+            cond_idxs = torch.nan
+            cond_len = torch.nan
 
         return tseq, audiopath, wav, cond, cond_len, cond_idxs
 
@@ -199,8 +208,8 @@ class XTTSDataset(torch.utils.data.Dataset):
             "wav_lengths": torch.tensor(wav.shape[-1], dtype=torch.long),
             "filenames": audiopath,
             "conditioning": cond.unsqueeze(1),
-            "cond_lens": torch.tensor(cond_len, dtype=torch.long),
-            "cond_idxs": torch.tensor(cond_idxs),
+            "cond_lens": torch.tensor(cond_len, dtype=torch.long) if cond_len is not torch.nan else torch.tensor([cond_len]),
+            "cond_idxs": torch.tensor(cond_idxs) if cond_idxs is not torch.nan else torch.tensor([cond_len]),
         }
         return res
 
@@ -221,6 +230,10 @@ class XTTSDataset(torch.utils.data.Dataset):
         batch["conditioning"] = torch.stack(batch["conditioning"])
         batch["cond_lens"] = torch.stack(batch["cond_lens"])
         batch["cond_idxs"] = torch.stack(batch["cond_idxs"])
+        if torch.any(batch["cond_idxs"].isnan()):
+            batch["cond_lens"] = None
+            batch["cond_idxs"] = None
+
         max_text_len = batch["text_lengths"].max()
         max_wav_len = batch["wav_lengths"].max()
 
diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py
index e93063fa..afb91f8f 100644
--- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py
+++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py
@@ -52,6 +52,7 @@ class GPTArgs(XttsArgs):
     xtts_checkpoint: str = ""
     gpt_checkpoint: str = ""  # if defined it will replace the gpt weights on xtts model
     vocoder: str = ""  # overide vocoder key on the config to avoid json write issues
+    use_masking_gt_as_prompt: bool = True
 
 
 def callback_clearml_load_save(operation_type, model_info):
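
Note (commentary, not part of the patch): with use_masking_gt_as_prompt=False, each training sample is expected to carry a pre-computed "reference_path" pointing at a similar utterance from the same speaker. The conditioning prompt is then sliced from that reference instead of from the ground-truth audio, and cond_len/cond_idxs are set to NaN so that collate_fn can drop them for the whole batch, since masking the prompt region in the target audio only makes sense when the prompt is cut from the target itself. The patch does not show how "reference_path" is produced; the sketch below is one possible offline selection step, assuming a hypothetical speaker-embedding callable (compute_embedding) and sample dicts with "audio_file" and "speaker_name" keys as used elsewhere in the XTTS trainer. Treat it as an illustration, not as code from this PR.

    import torch
    import torch.nn.functional as F


    def attach_most_similar_refs(samples, compute_embedding):
        """Attach a "reference_path" to each sample: the most similar other
        clip from the same speaker, by cosine similarity of speaker embeddings."""
        # group sample indices by speaker
        by_speaker = {}
        for idx, sample in enumerate(samples):
            by_speaker.setdefault(sample["speaker_name"], []).append(idx)

        for indices in by_speaker.values():
            if len(indices) < 2:
                continue  # nothing to pick from; load_item falls back to the GT audio
            paths = [samples[i]["audio_file"] for i in indices]
            embs = torch.stack([compute_embedding(p) for p in paths])  # (N, D)
            embs = F.normalize(embs, dim=-1)
            sim = embs @ embs.T                # pairwise cosine similarity
            sim.fill_diagonal_(-float("inf"))  # never match a clip with itself
            best = sim.argmax(dim=-1)          # most similar other clip per row
            for row, i in enumerate(indices):
                samples[i]["reference_path"] = paths[int(best[row])]
        return samples

Clips whose speaker has no other utterance simply keep no "reference_path", which matches the fallback to audiopath in load_item.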