diff --git a/TTS/tts/layers/xtts/trainer/dataset.py b/TTS/tts/layers/xtts/trainer/dataset.py index 9736ae6c..3ac22b5d 100644 --- a/TTS/tts/layers/xtts/trainer/dataset.py +++ b/TTS/tts/layers/xtts/trainer/dataset.py @@ -63,35 +63,48 @@ def load_audio(audiopath, sampling_rate): return audio class XTTSDataset(torch.utils.data.Dataset): - def __init__(self, config, samples, tokenizer, sample_rate): + def __init__(self, config, samples, tokenizer, sample_rate, is_eval=False): self.config = config model_args = config.model_args self.failed_samples = set() self.debug_failures = model_args.debug_loading_failures self.max_conditioning_length = model_args.max_conditioning_length self.min_conditioning_length = model_args.min_conditioning_length - - # self.samples = [] - # cache the samples and added type "0" for all samples - # ToDo: find a better way to deal with type - # for item in samples: - # self.samples.append([item['audio_file'], item["text"], 0]) - self.samples = samples - random.seed(config.training_seed) - # random.shuffle(self.samples) - random.shuffle(self.samples) - # order by language - self.samples = key_samples_by_col(self.samples, "language") - print(" > Sampling by language:", self.samples.keys()) - - # use always the output sampling rate to load in the highest quality + self.is_eval = is_eval + self.tokenizer = tokenizer self.sample_rate = sample_rate self.max_wav_len = model_args.max_wav_length self.max_text_len = model_args.max_text_length assert self.max_wav_len is not None and self.max_text_len is not None - # load specific vocabulary - self.tokenizer = tokenizer + self.samples = samples + if not is_eval: + random.seed(config.training_seed) + # random.shuffle(self.samples) + random.shuffle(self.samples) + # order by language + self.samples = key_samples_by_col(self.samples, "language") + print(" > Sampling by language:", self.samples.keys()) + else: + # for evaluation load and check samples that are corrupted to ensures the reproducibility + self.check_eval_samples() + + def check_eval_samples(self): + print("Filtering invalid eval samples!!") + new_samples = [] + for sample in self.samples: + try: + tseq, _, wav, _, _, _ = self.load_item(sample) + except: + pass + # Basically, this audio file is nonexistent or too long to be supported by the dataset. + if wav is None or \ + (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \ + (self.max_text_len is not None and tseq.shape[0] > self.max_text_len): + continue + new_samples.append(sample) + self.samples = new_samples + print("Total eval samples after filtering:", len(self.samples)) def get_text(self, text, lang): tokens = self.tokenizer.encode(text, lang) @@ -118,13 +131,17 @@ class XTTSDataset(torch.utils.data.Dataset): return tseq, audiopath, wav, cond, cond_len, cond_idxs def __getitem__(self, index): - # select a random language - lang = random.choice(list(self.samples.keys())) - # select random sample - index = random.randint(0, len(self.samples[lang]) - 1) - sample = self.samples[lang][index] - # a unique id for each sampel to deal with fails - sample_id = lang+"_"+str(index) + if self.is_eval: + sample = self.samples[index] + sample_id = str(index) + else: + # select a random language + lang = random.choice(list(self.samples.keys())) + # select random sample + index = random.randint(0, len(self.samples[lang]) - 1) + sample = self.samples[lang][index] + # a unique id for each sampel to deal with fails + sample_id = lang+"_"+str(index) # ignore samples that we already know that is not valid ones if sample_id in self.failed_samples: @@ -167,11 +184,14 @@ class XTTSDataset(torch.utils.data.Dataset): return res def __len__(self): + if self.is_eval: + return len(self.samples) return sum([len(v) for v in self.samples.values()]) def collate_fn(self, batch): # convert list of dicts to dict of lists B = len(batch) + batch = {k: [dic[k] for dic in batch] for k in batch[0]} # stack for features that already have the same shape @@ -198,5 +218,4 @@ class XTTSDataset(torch.utils.data.Dataset): batch["wav"] = wav_padded batch["padded_text"] = text_padded - return batch diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py index 6494f336..71cfd6e4 100644 --- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py +++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py @@ -327,7 +327,7 @@ class GPTTrainer(BaseTTS): else: # Todo: remove the randomness of dataset when it is eval # init dataloader - dataset = XTTSDataset(self.config, samples, self.tokenizer, config.audio.sample_rate) + dataset = XTTSDataset(self.config, samples, self.tokenizer, config.audio.sample_rate, is_eval) # wait all the DDP process to be ready if num_gpus > 1: diff --git a/recipes/multilingual/xtts_v1/train_xtts.py b/recipes/multilingual/xtts_v1/train_xtts.py index fc2b5d8a..4e987a4f 100644 --- a/recipes/multilingual/xtts_v1/train_xtts.py +++ b/recipes/multilingual/xtts_v1/train_xtts.py @@ -355,7 +355,7 @@ if __name__ == "__main__": DASHBOARD_LOGGER = "tensorboard" LOGGER_URI = None RESTORE_PATH = None - BATCH_SIZE = 2 + BATCH_SIZE = 10 GRAD_ACUMM_STEPS = 1 NUM_LOADERS = 1