mirror of https://github.com/coqui-ai/TTS.git
Merge pull request #530 from mueller91/fix_split_dataset
fix: split_dataset
This commit is contained in:
commit
cf02ace5b7
|
@ -16,13 +16,14 @@ def split_dataset(items):
|
|||
np.random.shuffle(items)
|
||||
if is_multi_speaker:
|
||||
items_eval = []
|
||||
# most stupid code ever -- Fix it !
|
||||
speakers = [item[-1] for item in items]
|
||||
speaker_counter = Counter(speakers)
|
||||
while len(items_eval) < eval_split_size:
|
||||
speakers = [item[-1] for item in items]
|
||||
speaker_counter = Counter(speakers)
|
||||
item_idx = np.random.randint(0, len(items))
|
||||
if speaker_counter[items[item_idx][-1]] > 1:
|
||||
speaker_to_be_removed = items[item_idx][-1]
|
||||
if speaker_counter[speaker_to_be_removed] > 1:
|
||||
items_eval.append(items[item_idx])
|
||||
speaker_counter[speaker_to_be_removed] -= 1
|
||||
del items[item_idx]
|
||||
return items_eval, items
|
||||
return items[:eval_split_size], items[eval_split_size:]
|
||||
|
@ -127,6 +128,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
|||
return model
|
||||
|
||||
|
||||
|
||||
def check_config_tts(c):
|
||||
check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str)
|
||||
check_argument('run_name', c, restricted=True, val_type=str)
|
||||
|
|
Loading…
Reference in New Issue