Merge pull request #530 from mueller91/fix_split_dataset

fix: split_dataset
This commit is contained in:
Eren Gölge 2020-09-28 12:42:40 +02:00 committed by GitHub
commit cf02ace5b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 4 deletions

View File

@ -16,13 +16,14 @@ def split_dataset(items):
np.random.shuffle(items)
if is_multi_speaker:
items_eval = []
# most stupid code ever -- Fix it !
speakers = [item[-1] for item in items]
speaker_counter = Counter(speakers)
while len(items_eval) < eval_split_size:
speakers = [item[-1] for item in items]
speaker_counter = Counter(speakers)
item_idx = np.random.randint(0, len(items))
if speaker_counter[items[item_idx][-1]] > 1:
speaker_to_be_removed = items[item_idx][-1]
if speaker_counter[speaker_to_be_removed] > 1:
items_eval.append(items[item_idx])
speaker_counter[speaker_to_be_removed] -= 1
del items[item_idx]
return items_eval, items
return items[:eval_split_size], items[eval_split_size:]
@ -127,6 +128,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
return model
def check_config_tts(c):
check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str)
check_argument('run_name', c, restricted=True, val_type=str)