mirror of https://github.com/coqui-ai/TTS.git
move split_dataset
This commit is contained in:
parent 9c18e40f64
commit 93a00373f6
@@ -2,19 +2,41 @@ import os
 import re
 import sys
 import xml.etree.ElementTree as ET
+from collections import Counter
 from glob import glob
 from pathlib import Path
 from typing import List
 
+import numpy as np
 from tqdm import tqdm
 
-from TTS.tts.utils.generic_utils import split_dataset
-
 ####################
 # UTILITIES
 ####################
 
 
+def split_dataset(items):
+    speakers = [item[-1] for item in items]
+    is_multi_speaker = len(set(speakers)) > 1
+    eval_split_size = min(500, int(len(items) * 0.01))
+    assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples."
+    np.random.seed(0)
+    np.random.shuffle(items)
+    if is_multi_speaker:
+        items_eval = []
+        speakers = [item[-1] for item in items]
+        speaker_counter = Counter(speakers)
+        while len(items_eval) < eval_split_size:
+            item_idx = np.random.randint(0, len(items))
+            speaker_to_be_removed = items[item_idx][-1]
+            if speaker_counter[speaker_to_be_removed] > 1:
+                items_eval.append(items[item_idx])
+                speaker_counter[speaker_to_be_removed] -= 1
+                del items[item_idx]
+        return items_eval, items
+    return items[:eval_split_size], items[eval_split_size:]
+
+
 def load_meta_data(datasets, eval_split=True):
     meta_data_train_all = []
     meta_data_eval_all = [] if eval_split else None
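For reference, a minimal sketch of how the split_dataset helper added above behaves on a toy metadata list. The sample texts, wav paths, and speaker names are invented for illustration, and the [text, wav_path, speaker_name] item layout is assumed from the way the code indexes item[-1]:

import numpy as np  # split_dataset as defined in the hunk above relies on numpy

# assumes split_dataset from the hunk above is in scope
# 400 invented samples across 4 speakers; note that split_dataset shuffles and
# mutates this list in place while carving out the eval portion.
items = [[f"sentence {i}", f"wavs/{i}.wav", f"speaker_{i % 4}"] for i in range(400)]

eval_items, train_items = split_dataset(items)
print(len(eval_items), len(train_items))  # 4 and 396: eval size is min(500, 1% of the corpus)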
@@ -38,7 +60,7 @@ def load_meta_data(datasets, eval_split=True):
         meta_data_train_all += meta_data_train
         # load attention masks for duration predictor training
         if dataset.meta_file_attn_mask is not None:
-            meta_data = dict(load_attention_mask_meta_data(dataset['meta_file_attn_mask']))
+            meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
             for idx, ins in enumerate(meta_data_train_all):
                 attn_file = meta_data[ins[1]].strip()
                 meta_data_train_all[idx].append(attn_file)
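The attention-mask block above consumes meta_file_attn_mask as a mapping from a sample's wav path (ins[1]) to an alignment-mask file. A minimal stand-in parser in that spirit, assuming a two-column pipe-separated text file; the repo's actual load_attention_mask_meta_data may differ in details:

def load_attention_mask_meta_data_sketch(metafile_path):
    # Hypothetical reader: each line is assumed to hold "<wav_path>|<attn_mask_path>",
    # so dict(...) over the result maps wav paths to attention-mask files,
    # which is what the lookup meta_data[ins[1]] above expects.
    pairs = []
    with open(metafile_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            wav_file, attn_file = line.strip().split("|")
            pairs.append((wav_file, attn_file))
    return pairs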
@@ -1,32 +1,8 @@
-import importlib
-import re
-from collections import Counter
+import torch
 
 from TTS.utils.generic_utils import find_module
 
 
-def split_dataset(items):
-    speakers = [item[-1] for item in items]
-    is_multi_speaker = len(set(speakers)) > 1
-    eval_split_size = min(500, int(len(items) * 0.01))
-    assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples."
-    np.random.seed(0)
-    np.random.shuffle(items)
-    if is_multi_speaker:
-        items_eval = []
-        speakers = [item[-1] for item in items]
-        speaker_counter = Counter(speakers)
-        while len(items_eval) < eval_split_size:
-            item_idx = np.random.randint(0, len(items))
-            speaker_to_be_removed = items[item_idx][-1]
-            if speaker_counter[speaker_to_be_removed] > 1:
-                items_eval.append(items[item_idx])
-                speaker_counter[speaker_to_be_removed] -= 1
-                del items[item_idx]
-        return items_eval, items
-    return items[:eval_split_size], items[eval_split_size:]
-
-
 # from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
 def sequence_mask(sequence_length, max_len=None):
     if max_len is None:
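The sequence_mask helper left in this file (credited to the linked gist) turns a batch of lengths into a boolean padding mask. A compact PyTorch sketch of the same idea, not necessarily identical line for line to the repo's version:

import torch

def sequence_mask_sketch(sequence_length, max_len=None):
    # mask[b, t] is True for t < sequence_length[b]; padded frames come out False.
    if max_len is None:
        max_len = int(sequence_length.max())
    seq_range = torch.arange(max_len, device=sequence_length.device)
    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)

lengths = torch.tensor([3, 5, 2])
print(sequence_mask_sketch(lengths))
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True],
#         [ True,  True, False, False, False]])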