diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index a0aba29a..2b165951 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -16,13 +16,14 @@ def split_dataset(items): np.random.shuffle(items) if is_multi_speaker: items_eval = [] - # most stupid code ever -- Fix it ! + speakers = [item[-1] for item in items] + speaker_counter = Counter(speakers) while len(items_eval) < eval_split_size: - speakers = [item[-1] for item in items] - speaker_counter = Counter(speakers) item_idx = np.random.randint(0, len(items)) - if speaker_counter[items[item_idx][-1]] > 1: + speaker_to_be_removed = items[item_idx][-1] + if speaker_counter[speaker_to_be_removed] > 1: items_eval.append(items[item_idx]) + speaker_counter[speaker_to_be_removed] -= 1 del items[item_idx] return items_eval, items return items[:eval_split_size], items[eval_split_size:] @@ -127,6 +128,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): return model + def check_config_tts(c): check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str) check_argument('run_name', c, restricted=True, val_type=str)