Add reproducible evaluation

Author: Edresson Casanova · 2023-10-13 15:49:37 -03:00
Parent: 40a4e631ea · Commit: 47d613df3a
3 changed files with 47 additions and 28 deletions

File 1 of 3: the XTTS trainer dataset (XTTSDataset)

@@ -63,35 +63,48 @@ def load_audio(audiopath, sampling_rate):
     return audio
 
 
 class XTTSDataset(torch.utils.data.Dataset):
-    def __init__(self, config, samples, tokenizer, sample_rate):
+    def __init__(self, config, samples, tokenizer, sample_rate, is_eval=False):
         self.config = config
         model_args = config.model_args
         self.failed_samples = set()
         self.debug_failures = model_args.debug_loading_failures
         self.max_conditioning_length = model_args.max_conditioning_length
         self.min_conditioning_length = model_args.min_conditioning_length
-        # self.samples = []
-        # cache the samples and added type "0" for all samples
-        # ToDo: find a better way to deal with type
-        # for item in samples:
-        #     self.samples.append([item['audio_file'], item["text"], 0])
-        self.samples = samples
-        random.seed(config.training_seed)
-        # random.shuffle(self.samples)
-        random.shuffle(self.samples)
-        # order by language
-        self.samples = key_samples_by_col(self.samples, "language")
-        print(" > Sampling by language:", self.samples.keys())
-        # use always the output sampling rate to load in the highest quality
+        self.is_eval = is_eval
+        self.tokenizer = tokenizer
         self.sample_rate = sample_rate
         self.max_wav_len = model_args.max_wav_length
         self.max_text_len = model_args.max_text_length
+        assert self.max_wav_len is not None and self.max_text_len is not None
 
-        # load specific vocabulary
-        self.tokenizer = tokenizer
+        self.samples = samples
+        if not is_eval:
+            random.seed(config.training_seed)
+            # random.shuffle(self.samples)
+            random.shuffle(self.samples)
+            # order by language
+            self.samples = key_samples_by_col(self.samples, "language")
+            print(" > Sampling by language:", self.samples.keys())
+        else:
+            # for evaluation, load and check every sample up front and drop the
+            # corrupted ones, so the eval set is reproducible across runs
+            self.check_eval_samples()
+
+    def check_eval_samples(self):
+        print("Filtering invalid eval samples!!")
+        new_samples = []
+        for sample in self.samples:
+            try:
+                tseq, _, wav, _, _, _ = self.load_item(sample)
+            except:
+                # loading raised: mark the sample invalid rather than leaving
+                # tseq/wav unbound (a bare `pass` here would crash below)
+                tseq, wav = None, None
+            # drop the sample if the audio file is missing or too long, or the
+            # tokenized text is too long for the dataset limits
+            if wav is None or \
+                    (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
+                    (self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
+                continue
+            new_samples.append(sample)
+        self.samples = new_samples
+        print("Total eval samples after filtering:", len(self.samples))
     def get_text(self, text, lang):
         tokens = self.tokenizer.encode(text, lang)
@@ -118,13 +131,17 @@ class XTTSDataset(torch.utils.data.Dataset):
         return tseq, audiopath, wav, cond, cond_len, cond_idxs
 
     def __getitem__(self, index):
-        # select a random language
-        lang = random.choice(list(self.samples.keys()))
-        # select random sample
-        index = random.randint(0, len(self.samples[lang]) - 1)
-        sample = self.samples[lang][index]
-        # a unique id for each sampel to deal with fails
-        sample_id = lang+"_"+str(index)
+        if self.is_eval:
+            # deterministic path: the index maps straight into the flat sample list
+            sample = self.samples[index]
+            sample_id = str(index)
+        else:
+            # select a random language
+            lang = random.choice(list(self.samples.keys()))
+            # select random sample
+            index = random.randint(0, len(self.samples[lang]) - 1)
+            sample = self.samples[lang][index]
+            # a unique id for each sample, used to track loading failures
+            sample_id = lang + "_" + str(index)
 
         # ignore samples that we already know are not valid
         if sample_id in self.failed_samples:
@@ -167,11 +184,14 @@ class XTTSDataset(torch.utils.data.Dataset):
         return res
 
     def __len__(self):
+        if self.is_eval:
+            # eval keeps a flat sample list, so its length is the dataset size
+            return len(self.samples)
         return sum([len(v) for v in self.samples.values()])
 
     def collate_fn(self, batch):
         # convert list of dicts to dict of lists
         B = len(batch)
         batch = {k: [dic[k] for dic in batch] for k in batch[0]}
         # stack for features that already have the same shape
@@ -198,5 +218,4 @@ class XTTSDataset(torch.utils.data.Dataset):
batch["wav"] = wav_padded
batch["padded_text"] = text_padded
return batch
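Taken together, the dataset changes make evaluation deterministic: with is_eval=True the samples stay in one flat, fixed-order list, corrupted items are filtered once up front, and __getitem__ / __len__ become plain index-based lookups instead of random draws. A minimal usage sketch of that contract (config, tokenizer, and eval_samples are assumed placeholders, not objects defined in this diff):

    # Two passes over the eval dataset now visit identical items in identical order.
    eval_ds = XTTSDataset(config, eval_samples, tokenizer, config.audio.sample_rate, is_eval=True)
    for i in range(len(eval_ds)):   # __len__ returns len(self.samples) when is_eval=True
        item = eval_ds[i]           # index-based lookup, no random.choice over languages

The training path is untouched: it still seeds, shuffles, buckets samples by language, and draws randomly on every call.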

File 2 of 3: the GPT trainer (GPTTrainer)

@@ -327,7 +327,7 @@ class GPTTrainer(BaseTTS):
         else:
             # ToDo: remove the randomness of the dataset when it is eval
             # init dataloader
-            dataset = XTTSDataset(self.config, samples, self.tokenizer, config.audio.sample_rate)
+            dataset = XTTSDataset(self.config, samples, self.tokenizer, config.audio.sample_rate, is_eval)
             # wait for all DDP processes to be ready
             if num_gpus > 1:
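This is the single call site that constructs the dataset, so threading is_eval through it switches the whole evaluation loader onto the deterministic path. For orientation, a hedged sketch of how the abridged get_data_loader code around this line might wrap the dataset (the DataLoader arguments below are an assumption, not code from this commit):

    from torch.utils.data import DataLoader

    loader = DataLoader(
        dataset,
        batch_size=config.eval_batch_size if is_eval else config.batch_size,
        shuffle=False,                  # ordering is decided inside XTTSDataset
        num_workers=config.num_loader_workers,
        collate_fn=dataset.collate_fn,  # pads wavs and token sequences per batch
    )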

File 3 of 3: the XTTS GPT training recipe

@@ -355,7 +355,7 @@ if __name__ == "__main__":
     DASHBOARD_LOGGER = "tensorboard"
     LOGGER_URI = None
     RESTORE_PATH = None
-    BATCH_SIZE = 2
+    BATCH_SIZE = 10
     GRAD_ACUMM_STEPS = 1
     NUM_LOADERS = 1
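A note on the recipe change: under the usual gradient-accumulation semantics, the effective optimizer batch is the product of the two knobs, so raising BATCH_SIZE from 2 to 10 with GRAD_ACUMM_STEPS = 1 raises the per-update batch from 2 to 10 (a sketch of the arithmetic, not recipe code):

    BATCH_SIZE = 10        # samples per forward pass
    GRAD_ACUMM_STEPS = 1   # forward passes accumulated per optimizer step
    effective_batch = BATCH_SIZE * GRAD_ACUMM_STEPS   # 10 * 1 = 10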