From 93a00373f6d5d15eb2d7f616cfed56725d9f8421 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 3 May 2021 16:43:54 +0200 Subject: [PATCH] move split_dataset --- TTS/tts/datasets/preprocess.py | 28 +++++++++++++++++++++++++--- TTS/tts/utils/generic_utils.py | 26 +------------------------- 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/TTS/tts/datasets/preprocess.py b/TTS/tts/datasets/preprocess.py index d6040493..4523d70b 100644 --- a/TTS/tts/datasets/preprocess.py +++ b/TTS/tts/datasets/preprocess.py @@ -2,19 +2,41 @@ import os import re import sys import xml.etree.ElementTree as ET +from collections import Counter from glob import glob from pathlib import Path from typing import List +import numpy as np from tqdm import tqdm -from TTS.tts.utils.generic_utils import split_dataset - #################### # UTILITIES #################### +def split_dataset(items): + speakers = [item[-1] for item in items] + is_multi_speaker = len(set(speakers)) > 1 + eval_split_size = min(500, int(len(items) * 0.01)) + assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples." + np.random.seed(0) + np.random.shuffle(items) + if is_multi_speaker: + items_eval = [] + speakers = [item[-1] for item in items] + speaker_counter = Counter(speakers) + while len(items_eval) < eval_split_size: + item_idx = np.random.randint(0, len(items)) + speaker_to_be_removed = items[item_idx][-1] + if speaker_counter[speaker_to_be_removed] > 1: + items_eval.append(items[item_idx]) + speaker_counter[speaker_to_be_removed] -= 1 + del items[item_idx] + return items_eval, items + return items[:eval_split_size], items[eval_split_size:] + + def load_meta_data(datasets, eval_split=True): meta_data_train_all = [] meta_data_eval_all = [] if eval_split else None @@ -38,7 +60,7 @@ def load_meta_data(datasets, eval_split=True): meta_data_train_all += meta_data_train # load attention masks for duration predictor training if dataset.meta_file_attn_mask is not None: - meta_data = dict(load_attention_mask_meta_data(dataset['meta_file_attn_mask'])) + meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"])) for idx, ins in enumerate(meta_data_train_all): attn_file = meta_data[ins[1]].strip() meta_data_train_all[idx].append(attn_file) diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index 9711c868..9f17da0b 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -1,32 +1,8 @@ -import importlib -import re -from collections import Counter +import torch from TTS.utils.generic_utils import find_module -def split_dataset(items): - speakers = [item[-1] for item in items] - is_multi_speaker = len(set(speakers)) > 1 - eval_split_size = min(500, int(len(items) * 0.01)) - assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples." - np.random.seed(0) - np.random.shuffle(items) - if is_multi_speaker: - items_eval = [] - speakers = [item[-1] for item in items] - speaker_counter = Counter(speakers) - while len(items_eval) < eval_split_size: - item_idx = np.random.randint(0, len(items)) - speaker_to_be_removed = items[item_idx][-1] - if speaker_counter[speaker_to_be_removed] > 1: - items_eval.append(items[item_idx]) - speaker_counter[speaker_to_be_removed] -= 1 - del items[item_idx] - return items_eval, items - return items[:eval_split_size], items[eval_split_size:] - - # from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1 def sequence_mask(sequence_length, max_len=None): if max_len is None: