mirror of https://github.com/coqui-ai/TTS.git
Add reproducible evaluation
This commit is contained in:
parent
40a4e631ea
commit
47d613df3a
|
@ -63,35 +63,48 @@ def load_audio(audiopath, sampling_rate):
|
||||||
return audio
|
return audio
|
||||||
|
|
||||||
class XTTSDataset(torch.utils.data.Dataset):
|
class XTTSDataset(torch.utils.data.Dataset):
|
||||||
def __init__(self, config, samples, tokenizer, sample_rate):
|
def __init__(self, config, samples, tokenizer, sample_rate, is_eval=False):
|
||||||
self.config = config
|
self.config = config
|
||||||
model_args = config.model_args
|
model_args = config.model_args
|
||||||
self.failed_samples = set()
|
self.failed_samples = set()
|
||||||
self.debug_failures = model_args.debug_loading_failures
|
self.debug_failures = model_args.debug_loading_failures
|
||||||
self.max_conditioning_length = model_args.max_conditioning_length
|
self.max_conditioning_length = model_args.max_conditioning_length
|
||||||
self.min_conditioning_length = model_args.min_conditioning_length
|
self.min_conditioning_length = model_args.min_conditioning_length
|
||||||
|
self.is_eval = is_eval
|
||||||
# self.samples = []
|
self.tokenizer = tokenizer
|
||||||
# cache the samples and added type "0" for all samples
|
|
||||||
# ToDo: find a better way to deal with type
|
|
||||||
# for item in samples:
|
|
||||||
# self.samples.append([item['audio_file'], item["text"], 0])
|
|
||||||
self.samples = samples
|
|
||||||
random.seed(config.training_seed)
|
|
||||||
# random.shuffle(self.samples)
|
|
||||||
random.shuffle(self.samples)
|
|
||||||
# order by language
|
|
||||||
self.samples = key_samples_by_col(self.samples, "language")
|
|
||||||
print(" > Sampling by language:", self.samples.keys())
|
|
||||||
|
|
||||||
# use always the output sampling rate to load in the highest quality
|
|
||||||
self.sample_rate = sample_rate
|
self.sample_rate = sample_rate
|
||||||
self.max_wav_len = model_args.max_wav_length
|
self.max_wav_len = model_args.max_wav_length
|
||||||
self.max_text_len = model_args.max_text_length
|
self.max_text_len = model_args.max_text_length
|
||||||
assert self.max_wav_len is not None and self.max_text_len is not None
|
assert self.max_wav_len is not None and self.max_text_len is not None
|
||||||
|
|
||||||
# load specific vocabulary
|
self.samples = samples
|
||||||
self.tokenizer = tokenizer
|
if not is_eval:
|
||||||
|
random.seed(config.training_seed)
|
||||||
|
# random.shuffle(self.samples)
|
||||||
|
random.shuffle(self.samples)
|
||||||
|
# order by language
|
||||||
|
self.samples = key_samples_by_col(self.samples, "language")
|
||||||
|
print(" > Sampling by language:", self.samples.keys())
|
||||||
|
else:
|
||||||
|
# for evaluation load and check samples that are corrupted to ensures the reproducibility
|
||||||
|
self.check_eval_samples()
|
||||||
|
|
||||||
|
def check_eval_samples(self):
|
||||||
|
print("Filtering invalid eval samples!!")
|
||||||
|
new_samples = []
|
||||||
|
for sample in self.samples:
|
||||||
|
try:
|
||||||
|
tseq, _, wav, _, _, _ = self.load_item(sample)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
|
||||||
|
if wav is None or \
|
||||||
|
(self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
|
||||||
|
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
|
||||||
|
continue
|
||||||
|
new_samples.append(sample)
|
||||||
|
self.samples = new_samples
|
||||||
|
print("Total eval samples after filtering:", len(self.samples))
|
||||||
|
|
||||||
def get_text(self, text, lang):
|
def get_text(self, text, lang):
|
||||||
tokens = self.tokenizer.encode(text, lang)
|
tokens = self.tokenizer.encode(text, lang)
|
||||||
|
@ -118,13 +131,17 @@ class XTTSDataset(torch.utils.data.Dataset):
|
||||||
return tseq, audiopath, wav, cond, cond_len, cond_idxs
|
return tseq, audiopath, wav, cond, cond_len, cond_idxs
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
# select a random language
|
if self.is_eval:
|
||||||
lang = random.choice(list(self.samples.keys()))
|
sample = self.samples[index]
|
||||||
# select random sample
|
sample_id = str(index)
|
||||||
index = random.randint(0, len(self.samples[lang]) - 1)
|
else:
|
||||||
sample = self.samples[lang][index]
|
# select a random language
|
||||||
# a unique id for each sampel to deal with fails
|
lang = random.choice(list(self.samples.keys()))
|
||||||
sample_id = lang+"_"+str(index)
|
# select random sample
|
||||||
|
index = random.randint(0, len(self.samples[lang]) - 1)
|
||||||
|
sample = self.samples[lang][index]
|
||||||
|
# a unique id for each sampel to deal with fails
|
||||||
|
sample_id = lang+"_"+str(index)
|
||||||
|
|
||||||
# ignore samples that we already know that is not valid ones
|
# ignore samples that we already know that is not valid ones
|
||||||
if sample_id in self.failed_samples:
|
if sample_id in self.failed_samples:
|
||||||
|
@ -167,11 +184,14 @@ class XTTSDataset(torch.utils.data.Dataset):
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
if self.is_eval:
|
||||||
|
return len(self.samples)
|
||||||
return sum([len(v) for v in self.samples.values()])
|
return sum([len(v) for v in self.samples.values()])
|
||||||
|
|
||||||
def collate_fn(self, batch):
|
def collate_fn(self, batch):
|
||||||
# convert list of dicts to dict of lists
|
# convert list of dicts to dict of lists
|
||||||
B = len(batch)
|
B = len(batch)
|
||||||
|
|
||||||
batch = {k: [dic[k] for dic in batch] for k in batch[0]}
|
batch = {k: [dic[k] for dic in batch] for k in batch[0]}
|
||||||
|
|
||||||
# stack for features that already have the same shape
|
# stack for features that already have the same shape
|
||||||
|
@ -198,5 +218,4 @@ class XTTSDataset(torch.utils.data.Dataset):
|
||||||
|
|
||||||
batch["wav"] = wav_padded
|
batch["wav"] = wav_padded
|
||||||
batch["padded_text"] = text_padded
|
batch["padded_text"] = text_padded
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|
|
@ -327,7 +327,7 @@ class GPTTrainer(BaseTTS):
|
||||||
else:
|
else:
|
||||||
# Todo: remove the randomness of dataset when it is eval
|
# Todo: remove the randomness of dataset when it is eval
|
||||||
# init dataloader
|
# init dataloader
|
||||||
dataset = XTTSDataset(self.config, samples, self.tokenizer, config.audio.sample_rate)
|
dataset = XTTSDataset(self.config, samples, self.tokenizer, config.audio.sample_rate, is_eval)
|
||||||
|
|
||||||
# wait all the DDP process to be ready
|
# wait all the DDP process to be ready
|
||||||
if num_gpus > 1:
|
if num_gpus > 1:
|
||||||
|
|
|
@ -355,7 +355,7 @@ if __name__ == "__main__":
|
||||||
DASHBOARD_LOGGER = "tensorboard"
|
DASHBOARD_LOGGER = "tensorboard"
|
||||||
LOGGER_URI = None
|
LOGGER_URI = None
|
||||||
RESTORE_PATH = None
|
RESTORE_PATH = None
|
||||||
BATCH_SIZE = 2
|
BATCH_SIZE = 10
|
||||||
GRAD_ACUMM_STEPS = 1
|
GRAD_ACUMM_STEPS = 1
|
||||||
NUM_LOADERS = 1
|
NUM_LOADERS = 1
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue