mirror of https://github.com/coqui-ai/TTS.git
Add reproducible evaluation
This commit is contained in:
parent 40a4e631ea
commit 47d613df3a
@@ -63,35 +63,48 @@ def load_audio(audiopath, sampling_rate):
     return audio


 class XTTSDataset(torch.utils.data.Dataset):
-    def __init__(self, config, samples, tokenizer, sample_rate):
+    def __init__(self, config, samples, tokenizer, sample_rate, is_eval=False):
         self.config = config
         model_args = config.model_args
         self.failed_samples = set()
         self.debug_failures = model_args.debug_loading_failures
         self.max_conditioning_length = model_args.max_conditioning_length
         self.min_conditioning_length = model_args.min_conditioning_length
-
-        # self.samples = []
-        # cache the samples and added type "0" for all samples
-        # ToDo: find a better way to deal with type
-        # for item in samples:
-        #     self.samples.append([item['audio_file'], item["text"], 0])
-        self.samples = samples
-        random.seed(config.training_seed)
-        # random.shuffle(self.samples)
-        random.shuffle(self.samples)
-        # order by language
-        self.samples = key_samples_by_col(self.samples, "language")
-        print(" > Sampling by language:", self.samples.keys())
-
         # always use the output sampling rate to load audio in the highest quality
+        self.is_eval = is_eval
         self.tokenizer = tokenizer
         self.sample_rate = sample_rate
         self.max_wav_len = model_args.max_wav_length
         self.max_text_len = model_args.max_text_length
         assert self.max_wav_len is not None and self.max_text_len is not None

-        # load specific vocabulary
-        self.tokenizer = tokenizer
+        self.samples = samples
+        if not is_eval:
+            random.seed(config.training_seed)
+            # random.shuffle(self.samples)
+            random.shuffle(self.samples)
+            # order by language
+            self.samples = key_samples_by_col(self.samples, "language")
+            print(" > Sampling by language:", self.samples.keys())
+        else:
+            # for evaluation, load and check the samples up front and drop corrupted ones so the eval split stays reproducible
+            self.check_eval_samples()
+
+    def check_eval_samples(self):
+        print("Filtering invalid eval samples!!")
+        new_samples = []
+        for sample in self.samples:
+            tseq, wav = None, None  # so a failed load_item cannot leave these undefined
+            try:
+                tseq, _, wav, _, _, _ = self.load_item(sample)
+            except:
+                pass
+            # the audio file is nonexistent, or the audio or text is too long to be supported by the dataset
+            if wav is None or \
+                    (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
+                    (self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
+                continue
+            new_samples.append(sample)
+        self.samples = new_samples
+        print("Total eval samples after filtering:", len(self.samples))

     def get_text(self, text, lang):
         tokens = self.tokenizer.encode(text, lang)
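The eval-time filtering above is what makes evaluation reproducible: instead of skipping bad samples lazily while iterating (which shifts indices between runs), the eval split is validated once up front so that index i always maps to the same sample. Below is a minimal, self-contained sketch of that pre-filtering pattern; load_item, the sample fields, and the length limits are illustrative stand-ins, not the actual XTTS code.

    # Sketch: validate an eval split once so that later indexing is deterministic.
    import torch

    MAX_WAV_LEN = 16000 * 11   # assumed limit: 11 seconds at 16 kHz
    MAX_TEXT_LEN = 200         # assumed token limit

    def load_item(sample):
        # stand-in for XTTSDataset.load_item: returns (token_seq, wav) or raises on bad files
        wav = torch.randn(1, sample["num_samples"])
        tokens = torch.randint(0, 100, (sample["num_tokens"],))
        return tokens, wav

    def filter_eval_samples(samples, max_wav_len=MAX_WAV_LEN, max_text_len=MAX_TEXT_LEN):
        kept = []
        for sample in samples:
            tokens, wav = None, None
            try:
                tokens, wav = load_item(sample)
            except Exception:
                pass  # treat unreadable files like missing audio
            if wav is None or wav.shape[-1] > max_wav_len or tokens.shape[0] > max_text_len:
                continue
            kept.append(sample)
        return kept

    samples = [
        {"num_samples": 16000 * 5, "num_tokens": 50},   # kept
        {"num_samples": 16000 * 60, "num_tokens": 50},  # audio too long, dropped
    ]
    print(len(filter_eval_samples(samples)))  # 1, and the same result on every run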
@@ -118,13 +131,17 @@ class XTTSDataset(torch.utils.data.Dataset):
         return tseq, audiopath, wav, cond, cond_len, cond_idxs

     def __getitem__(self, index):
-        # select a random language
-        lang = random.choice(list(self.samples.keys()))
-        # select a random sample
-        index = random.randint(0, len(self.samples[lang]) - 1)
-        sample = self.samples[lang][index]
-        # a unique id for each sample to deal with fails
-        sample_id = lang + "_" + str(index)
+        if self.is_eval:
+            sample = self.samples[index]
+            sample_id = str(index)
+        else:
+            # select a random language
+            lang = random.choice(list(self.samples.keys()))
+            # select a random sample
+            index = random.randint(0, len(self.samples[lang]) - 1)
+            sample = self.samples[lang][index]
+            # a unique id for each sample to deal with fails
+            sample_id = lang + "_" + str(index)

         # ignore samples that we already know are not valid
         if sample_id in self.failed_samples:
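In the training branch, self.samples stays a dict keyed by language (the output of key_samples_by_col) and every call draws a random language and a random index, while the eval branch keeps a flat list and uses the incoming index directly, so item i is always the same sample. A toy sketch of the two access patterns, with placeholder sample dicts rather than real XTTS samples:

    import random

    samples_by_lang = {
        "en": [{"text": "hello"}, {"text": "world"}],
        "pt": [{"text": "ola"}],
    }
    eval_samples = [{"text": "hello"}, {"text": "ola"}]

    def get_train_item(samples):
        # training: random language, then a random sample within that language
        lang = random.choice(list(samples.keys()))
        index = random.randint(0, len(samples[lang]) - 1)
        return samples[lang][index], lang + "_" + str(index)

    def get_eval_item(samples, index):
        # evaluation: the requested index is used as-is, so item i never changes
        return samples[index], str(index)

    print(get_eval_item(eval_samples, 0))   # always ({'text': 'hello'}, '0')
    print(get_train_item(samples_by_lang))  # varies from call to call

    # the two __len__ code paths below correspond to these two layouts
    print(sum(len(v) for v in samples_by_lang.values()))  # 3 training samples
    print(len(eval_samples))                              # 2 eval samples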
@@ -167,11 +184,14 @@ class XTTSDataset(torch.utils.data.Dataset):
         return res

     def __len__(self):
+        if self.is_eval:
+            return len(self.samples)
         return sum([len(v) for v in self.samples.values()])

     def collate_fn(self, batch):
         # convert list of dicts to dict of lists
         B = len(batch)

         batch = {k: [dic[k] for dic in batch] for k in batch[0]}

         # stack the features that already have the same shape
@@ -198,5 +218,4 @@ class XTTSDataset(torch.utils.data.Dataset):

         batch["wav"] = wav_padded
         batch["padded_text"] = text_padded
-
         return batch
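collate_fn first converts the list of per-sample dicts into a dict of lists and then pads the variable-length features (the waveform and the text tokens) to the longest item in the batch, which is where wav_padded and text_padded come from. A small, self-contained sketch of that pattern; the field names and shapes are chosen for illustration rather than copied from XTTS:

    import torch

    def collate(batch):
        # convert list of dicts to dict of lists
        batch = {k: [d[k] for d in batch] for k in batch[0]}

        # pad variable-length features to the longest item in the batch
        max_wav = max(w.shape[-1] for w in batch["wav"])
        max_txt = max(t.shape[0] for t in batch["text"])
        wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav)
        text_padded = torch.zeros(len(batch["text"]), max_txt, dtype=torch.long)
        for i, (w, t) in enumerate(zip(batch["wav"], batch["text"])):
            wav_padded[i, :, : w.shape[-1]] = w
            text_padded[i, : t.shape[0]] = t

        batch["wav"] = wav_padded
        batch["padded_text"] = text_padded
        return batch

    items = [
        {"wav": torch.randn(1, 300), "text": torch.randint(0, 10, (7,))},
        {"wav": torch.randn(1, 500), "text": torch.randint(0, 10, (4,))},
    ]
    out = collate(items)
    print(out["wav"].shape, out["padded_text"].shape)  # torch.Size([2, 1, 500]) torch.Size([2, 7])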
@@ -327,7 +327,7 @@ class GPTTrainer(BaseTTS):
         else:
             # Todo: remove the randomness of the dataset when it is eval
             # init dataloader
-            dataset = XTTSDataset(self.config, samples, self.tokenizer, config.audio.sample_rate)
+            dataset = XTTSDataset(self.config, samples, self.tokenizer, config.audio.sample_rate, is_eval)

             # wait for all the DDP processes to be ready
             if num_gpus > 1:
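With is_eval passed through, the eval dataset is a fixed, pre-filtered list, so a reproducible loader mainly needs shuffling disabled. A rough sketch of that setup on a toy dataset; the DataLoader arguments are generic PyTorch, not a copy of GPTTrainer's actual call:

    import torch
    from torch.utils.data import DataLoader, Dataset

    class ToyEvalSet(Dataset):
        # stand-in for the pre-filtered eval XTTSDataset
        def __init__(self, n):
            self.items = [torch.tensor([i]) for i in range(n)]

        def __len__(self):
            return len(self.items)

        def __getitem__(self, index):
            return self.items[index]

    loader = DataLoader(
        ToyEvalSet(4),
        batch_size=2,
        shuffle=False,   # no shuffling: batches come out in the same order every run
        drop_last=False,
    )
    print([b.tolist() for b in loader])  # [[[0], [1]], [[2], [3]]] on every run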
@@ -355,7 +355,7 @@ if __name__ == "__main__":
     DASHBOARD_LOGGER = "tensorboard"
     LOGGER_URI = None
     RESTORE_PATH = None
-    BATCH_SIZE = 2
+    BATCH_SIZE = 10
     GRAD_ACUMM_STEPS = 1
     NUM_LOADERS = 1
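For reference, the batch-size change interacts with gradient accumulation: roughly BATCH_SIZE * GRAD_ACUMM_STEPS samples contribute to each optimizer step per GPU. A quick check with the values above, assuming a single-GPU run:

    BATCH_SIZE = 10
    GRAD_ACUMM_STEPS = 1
    NUM_GPUS = 1  # assumption for a single-GPU run

    effective_batch = BATCH_SIZE * GRAD_ACUMM_STEPS * NUM_GPUS
    print(effective_batch)  # 10 samples per optimizer step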