move split_dataset

Eren Gölge 2021-05-03 16:43:54 +02:00
parent 9c18e40f64
commit 93a00373f6
2 changed files with 26 additions and 28 deletions

View File

@@ -2,19 +2,41 @@ import os
 import re
 import sys
 import xml.etree.ElementTree as ET
+from collections import Counter
 from glob import glob
 from pathlib import Path
 from typing import List
+import numpy as np
 from tqdm import tqdm
-from TTS.tts.utils.generic_utils import split_dataset
 ####################
 # UTILITIES
 ####################
+def split_dataset(items):
+    speakers = [item[-1] for item in items]
+    is_multi_speaker = len(set(speakers)) > 1
+    eval_split_size = min(500, int(len(items) * 0.01))
+    assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples."
+    np.random.seed(0)
+    np.random.shuffle(items)
+    if is_multi_speaker:
+        items_eval = []
+        speakers = [item[-1] for item in items]
+        speaker_counter = Counter(speakers)
+        while len(items_eval) < eval_split_size:
+            item_idx = np.random.randint(0, len(items))
+            speaker_to_be_removed = items[item_idx][-1]
+            if speaker_counter[speaker_to_be_removed] > 1:
+                items_eval.append(items[item_idx])
+                speaker_counter[speaker_to_be_removed] -= 1
+                del items[item_idx]
+        return items_eval, items
+    return items[:eval_split_size], items[eval_split_size:]
 def load_meta_data(datasets, eval_split=True):
     meta_data_train_all = []
     meta_data_eval_all = [] if eval_split else None
@@ -38,7 +60,7 @@ def load_meta_data(datasets, eval_split=True):
         meta_data_train_all += meta_data_train
         # load attention masks for duration predictor training
         if dataset.meta_file_attn_mask is not None:
-            meta_data = dict(load_attention_mask_meta_data(dataset['meta_file_attn_mask']))
+            meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
             for idx, ins in enumerate(meta_data_train_all):
                 attn_file = meta_data[ins[1]].strip()
                 meta_data_train_all[idx].append(attn_file)
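
For context, a minimal usage sketch of the relocated split_dataset. The import path and the sample items below are assumptions for illustration and are not shown in this diff; note that the function shuffles and mutates the input list in place.

# Hypothetical example; the module path and the sample data are assumptions.
from TTS.tts.datasets.preprocess import split_dataset  # assumed new location of the function

# Each item ends with the speaker name, which split_dataset reads as item[-1].
items = [
    (f"Transcription {i}.", f"/data/wavs/sample_{i}.wav", f"speaker_{i % 3}")
    for i in range(300)
]

# eval_split_size = min(500, 1% of 300) = 3; the assert requires at least 100 items.
# The input list is shuffled with a fixed seed and the selected eval items are removed in place.
eval_items, train_items = split_dataset(items)
print(len(eval_items), len(train_items))  # 3 297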

View File

@@ -1,32 +1,8 @@
-import importlib
-import re
-from collections import Counter
+import torch
 from TTS.utils.generic_utils import find_module
-def split_dataset(items):
-    speakers = [item[-1] for item in items]
-    is_multi_speaker = len(set(speakers)) > 1
-    eval_split_size = min(500, int(len(items) * 0.01))
-    assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples."
-    np.random.seed(0)
-    np.random.shuffle(items)
-    if is_multi_speaker:
-        items_eval = []
-        speakers = [item[-1] for item in items]
-        speaker_counter = Counter(speakers)
-        while len(items_eval) < eval_split_size:
-            item_idx = np.random.randint(0, len(items))
-            speaker_to_be_removed = items[item_idx][-1]
-            if speaker_counter[speaker_to_be_removed] > 1:
-                items_eval.append(items[item_idx])
-                speaker_counter[speaker_to_be_removed] -= 1
-                del items[item_idx]
-        return items_eval, items
-    return items[:eval_split_size], items[eval_split_size:]
 # from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
 def sequence_mask(sequence_length, max_len=None):
     if max_len is None: