move split_dataset

Eren Gölge 2021-05-03 16:43:54 +02:00
parent 9c18e40f64
commit 93a00373f6
2 changed files with 26 additions and 28 deletions

View File

@@ -2,19 +2,41 @@ import os
import re
import sys
import xml.etree.ElementTree as ET
from collections import Counter
from glob import glob
from pathlib import Path
from typing import List
import numpy as np
from tqdm import tqdm
from TTS.tts.utils.generic_utils import split_dataset
####################
# UTILITIES
####################
def split_dataset(items):
    speakers = [item[-1] for item in items]
    is_multi_speaker = len(set(speakers)) > 1
    eval_split_size = min(500, int(len(items) * 0.01))
    assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples."
    np.random.seed(0)
    np.random.shuffle(items)
    if is_multi_speaker:
        items_eval = []
        speakers = [item[-1] for item in items]
        speaker_counter = Counter(speakers)
        while len(items_eval) < eval_split_size:
            item_idx = np.random.randint(0, len(items))
            speaker_to_be_removed = items[item_idx][-1]
            if speaker_counter[speaker_to_be_removed] > 1:
                items_eval.append(items[item_idx])
                speaker_counter[speaker_to_be_removed] -= 1
                del items[item_idx]
        return items_eval, items
    return items[:eval_split_size], items[eval_split_size:]
def load_meta_data(datasets, eval_split=True):
    meta_data_train_all = []
    meta_data_eval_all = [] if eval_split else None
@@ -38,7 +60,7 @@ def load_meta_data(datasets, eval_split=True):
        meta_data_train_all += meta_data_train
        # load attention masks for duration predictor training
        if dataset.meta_file_attn_mask is not None:
            meta_data = dict(load_attention_mask_meta_data(dataset['meta_file_attn_mask']))
            meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
            for idx, ins in enumerate(meta_data_train_all):
                attn_file = meta_data[ins[1]].strip()
                meta_data_train_all[idx].append(attn_file)
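
A quick illustration (not part of this commit) of how the relocated split_dataset behaves when called from this module; the toy items below use a hypothetical three-speaker layout but follow the [text, wav_file, speaker_name] shape the loaders produce. Since eval_split_size = min(500, int(len(items) * 0.01)), any corpus under 100 samples yields a split size of 0 and trips the assert, hence its message.

# illustrative sketch only -- toy items with a made-up 3-speaker layout
items = [[f"text {i}", f"wav_{i}.wav", f"speaker_{i % 3}"] for i in range(300)]
eval_items, train_items = split_dataset(items)
# eval_split_size = min(500, int(300 * 0.01)) = 3
assert len(eval_items) == 3 and len(train_items) == 297
# the multi-speaker branch guarantees every eval speaker keeps >= 1 train sample
assert {item[-1] for item in eval_items} <= {item[-1] for item in train_items}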

View File

@@ -1,32 +1,8 @@
import importlib
import re
from collections import Counter
import torch
from TTS.utils.generic_utils import find_module
def split_dataset(items):
    speakers = [item[-1] for item in items]
    is_multi_speaker = len(set(speakers)) > 1
    eval_split_size = min(500, int(len(items) * 0.01))
    assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples."
    np.random.seed(0)
    np.random.shuffle(items)
    if is_multi_speaker:
        items_eval = []
        speakers = [item[-1] for item in items]
        speaker_counter = Counter(speakers)
        while len(items_eval) < eval_split_size:
            item_idx = np.random.randint(0, len(items))
            speaker_to_be_removed = items[item_idx][-1]
            if speaker_counter[speaker_to_be_removed] > 1:
                items_eval.append(items[item_idx])
                speaker_counter[speaker_to_be_removed] -= 1
                del items[item_idx]
        return items_eval, items
    return items[:eval_split_size], items[eval_split_size:]
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
def sequence_mask(sequence_length, max_len=None):
    if max_len is None: