split dataset outside preprocessor

Eren Golge 2019-07-16 21:15:04 +02:00
parent b7036e458d
commit fd081c49b7
3 changed files with 32 additions and 10 deletions
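In short: the hard-coded 500-utterance eval split is deleted from the libri_tts preprocessor, a reusable split_dataset helper lands in utils/generic_utils.py, and train.py now derives its eval set by splitting the training metadata whenever no explicit validation meta file is configured. Short usage sketches follow each diff below.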

datasets/preprocess.py

@@ -146,7 +146,7 @@ def common_voice(root_path, meta_file):
     return items
 
 
-def libri_tts(root_path, meta_files=None, is_eval=False):
+def libri_tts(root_path, meta_files=None):
     """https://ai.google/tools/datasets/libri-tts/"""
     items = []
     if meta_files is None:
@@ -164,6 +164,4 @@ def libri_tts(root_path, meta_files=None, is_eval=False):
         items.append([text, wav_file, speaker_name])
     for item in items:
         assert os.path.exists(item[1]), f" [!] wav file is not exist - {item[1]}"
-    if meta_files is None:
-        return items[:500] if is_eval else items[500:]
     return items
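With the is_eval flag gone, libri_tts always returns the complete item list and callers perform the split themselves. A minimal sketch of the new call pattern, assuming a made-up LibriTTS root path (the function name, the [text, wav_file, speaker_name] item layout, and the (eval, train) return order come from the diffs; the path is hypothetical):

    # The preprocessor returns everything; the caller splits
    # (see split_dataset in utils/generic_utils.py below).
    items = libri_tts("/data/LibriTTS")                       # hypothetical root path
    meta_data_eval, meta_data_train = split_dataset(items)    # note the (eval, train) order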

train.py

@@ -21,7 +21,8 @@ from utils.generic_utils import (NoamLR, check_update, count_parameters,
                                  load_config, lr_decay,
                                  remove_experiment_folder, save_best_model,
                                  save_checkpoint, sequence_mask, weight_decay,
-                                 set_init_dict, copy_config_file, setup_model)
+                                 set_init_dict, copy_config_file, setup_model,
+                                 split_dataset)
 from utils.logger import Logger
 from utils.speakers import load_speaker_mapping, save_speaker_mapping, \
     get_speakers
@@ -44,15 +45,15 @@ def setup_loader(is_val=False, verbose=False):
     global meta_data_train
     global meta_data_eval
     if "meta_data_train" not in globals():
-        if c.meta_file_train:
+        if c.meta_file_train is not None:
             meta_data_train = get_preprocessor_by_name(c.dataset)(c.data_path, c.meta_file_train)
         else:
-            meta_data_train = get_preprocessor_by_name(c.dataset)(c.data_path, is_eval=False)
-    if "meta_data_eval" not in globals():
-        if c.meta_file_val:
+            meta_data_train = get_preprocessor_by_name(c.dataset)(c.data_path)
+    if "meta_data_eval" not in globals() and c.run_eval:
+        if c.meta_file_val is not None:
             meta_data_eval = get_preprocessor_by_name(c.dataset)(c.data_path, c.meta_file_val)
         else:
-            meta_data_eval = get_preprocessor_by_name(c.dataset)(c.data_path, is_eval=True)
+            meta_data_eval, meta_data_train = split_dataset(meta_data_train)
     if is_val and not c.run_eval:
         loader = None
     else:
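Two behavioral changes hide in this hunk: eval metadata is now only prepared when c.run_eval is set, and when no meta_file_val is configured the eval set is carved out of meta_data_train instead of being re-read through the preprocessor. Condensed, the new flow looks like this (a paraphrase of the hunk above, not the literal file; c and get_preprocessor_by_name are as in the diff):

    preprocessor = get_preprocessor_by_name(c.dataset)
    if "meta_data_train" not in globals():
        if c.meta_file_train is not None:
            meta_data_train = preprocessor(c.data_path, c.meta_file_train)
        else:
            meta_data_train = preprocessor(c.data_path)
    if "meta_data_eval" not in globals() and c.run_eval:
        if c.meta_file_val is not None:
            meta_data_eval = preprocessor(c.data_path, c.meta_file_val)
        else:
            # no explicit validation list: split one off the training items
            meta_data_eval, meta_data_train = split_dataset(meta_data_train)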

utils/generic_utils.py

@@ -10,7 +10,7 @@ import torch
 import subprocess
 import importlib
 import numpy as np
-from collections import OrderedDict
+from collections import OrderedDict, Counter
 from torch.autograd import Variable
 from utils.text import text_to_sequence
 
@@ -287,3 +287,26 @@ def setup_model(num_chars, num_speakers, c):
                 location_attn=c.location_attn,
                 separate_stopnet=c.separate_stopnet)
     return model
+
+
+def split_dataset(items):
+    is_multi_speaker = False
+    speakers = [item[-1] for item in items]
+    is_multi_speaker = len(set(speakers)) > 1
+    eval_split_size = 500 if 500 < len(items) * 0.01 else int(len(items) * 0.01)
+    np.random.seed(0)
+    np.random.shuffle(items)
+    if is_multi_speaker:
+        items_eval = []
+        # most stupid code ever -- Fix it !
+        while len(items_eval) < eval_split_size:
+            speakers = [item[-1] for item in items]
+            speaker_counter = Counter(speakers)
+            item_idx = np.random.randint(0, len(items))
+            if speaker_counter[items[item_idx][-1]] > 1:
+                items_eval.append(items[item_idx])
+                del items[item_idx]
+        return items_eval, items
+    else:
+        return items[:eval_split_size], items[eval_split_size:]
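The helper shuffles with a fixed seed (np.random.seed(0)), so the split is reproducible, and sizes the eval set as min(500, 1% of the items). In the multi-speaker branch it only moves an item whose speaker still has more than one utterance left in the pool, so every speaker keeps at least one training sample; the cost is that it rebuilds the speaker Counter on every loop pass (the "most stupid code ever" the comment flags), and the loop can spin indefinitely if fewer removable items exist than eval_split_size. A possible cleanup, offered as a sketch rather than the project's code, keeps a single Counter and decrements it:

    from collections import Counter
    import numpy as np

    def split_dataset_sketch(items, max_eval=500):
        # Same contract as split_dataset above: returns (eval_items, train_items).
        eval_split_size = min(max_eval, int(len(items) * 0.01))
        np.random.seed(0)                  # deterministic split, as in the original
        np.random.shuffle(items)
        if len({item[-1] for item in items}) > 1:        # multi-speaker corpus
            counter = Counter(item[-1] for item in items)
            items_eval = []
            while len(items_eval) < eval_split_size:
                idx = np.random.randint(0, len(items))
                speaker = items[idx][-1]
                if counter[speaker] > 1:   # keep >= 1 utterance per speaker in train
                    counter[speaker] -= 1
                    items_eval.append(items.pop(idx))
            return items_eval, items
        return items[:eval_split_size], items[eval_split_size:]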