fix: split_dataset() runtime reduced from O(N * |items|) to O(N) where N is the size of the eval split (max 500)

I noticed a significant speedup in the initial loading of large datasets such as Common Voice (from minutes to seconds).
mueller91 2020-09-23 23:27:51 +02:00
parent cfeeef7a7f
commit 227b9c8864
1 changed file with 6 additions and 4 deletions


@@ -16,13 +16,14 @@ def split_dataset(items):
     np.random.shuffle(items)
     if is_multi_speaker:
         items_eval = []
-        # most stupid code ever -- Fix it !
+        speakers = [item[-1] for item in items]
+        speaker_counter = Counter(speakers)
         while len(items_eval) < eval_split_size:
-            speakers = [item[-1] for item in items]
-            speaker_counter = Counter(speakers)
             item_idx = np.random.randint(0, len(items))
-            if speaker_counter[items[item_idx][-1]] > 1:
+            speaker_to_be_removed = items[item_idx][-1]
+            if speaker_counter[speaker_to_be_removed] > 1:
                 items_eval.append(items[item_idx])
+                speaker_counter[speaker_to_be_removed] -= 1
                 del items[item_idx]
         return items_eval, items
     return items[:eval_split_size], items[eval_split_size:]
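
For reference, a minimal self-contained sketch of the patched logic. This is not the exact file contents: the eval_split_size computation, the item layout (speaker name in the last field), and the synthetic data in the usage example are assumptions based on the commit title and the hunk above, not taken verbatim from the repository.

    # Sketch of the patched split_dataset(); assumptions are noted inline.
    from collections import Counter

    import numpy as np


    def split_dataset(items):
        """Split items into (eval_items, train_items)."""
        # Assumed: eval split is 1% of the data, capped at 500 (per the commit title).
        eval_split_size = min(500, int(len(items) * 0.01))
        # Assumed: each item keeps the speaker name in its last field.
        speakers = [item[-1] for item in items]
        is_multi_speaker = len(set(speakers)) > 1
        np.random.seed(0)
        np.random.shuffle(items)
        if is_multi_speaker:
            items_eval = []
            # Build the per-speaker counts once, instead of rebuilding the
            # Counter over all items on every loop iteration as the old code did.
            speaker_counter = Counter(speakers)
            while len(items_eval) < eval_split_size:
                item_idx = np.random.randint(0, len(items))
                speaker_to_be_removed = items[item_idx][-1]
                # Only move an item to eval if its speaker keeps at least one
                # utterance in the training split.
                if speaker_counter[speaker_to_be_removed] > 1:
                    items_eval.append(items[item_idx])
                    speaker_counter[speaker_to_be_removed] -= 1
                    del items[item_idx]
            return items_eval, items
        return items[:eval_split_size], items[eval_split_size:]


    # Usage example with hypothetical synthetic data:
    items = [("text", f"utt_{i}.wav", f"speaker_{i % 50}") for i in range(60000)]
    eval_items, train_items = split_dataset(items)

The speedup comes from replacing the per-iteration Counter rebuild (which scanned every remaining item on each pass) with a single Counter that is decremented as items move to the eval split.
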
@@ -127,6 +128,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
     return model


 def check_config_tts(c):
     check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str)
     check_argument('run_name', c, restricted=True, val_type=str)