rename preprocess.py -> formatters.py

Eren Gölge 2021-05-25 10:38:05 +02:00
parent facb782851
commit f07209d2e0
2 changed files with 26 additions and 94 deletions

View File

@@ -1,93 +1,12 @@
import os
import re
import sys
import xml.etree.ElementTree as ET
from collections import Counter
from glob import glob
from pathlib import Path
from typing import List

import numpy as np
from tqdm import tqdm

####################
# UTILITIES
####################

def split_dataset(items):
    speakers = [item[-1] for item in items]
    is_multi_speaker = len(set(speakers)) > 1
    eval_split_size = min(500, int(len(items) * 0.01))
    assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples."
    np.random.seed(0)
    np.random.shuffle(items)
    if is_multi_speaker:
        items_eval = []
        speakers = [item[-1] for item in items]
        speaker_counter = Counter(speakers)
        while len(items_eval) < eval_split_size:
            item_idx = np.random.randint(0, len(items))
            speaker_to_be_removed = items[item_idx][-1]
            if speaker_counter[speaker_to_be_removed] > 1:
                items_eval.append(items[item_idx])
                speaker_counter[speaker_to_be_removed] -= 1
                del items[item_idx]
        return items_eval, items
    return items[:eval_split_size], items[eval_split_size:]
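
# NOTE (editor's sketch, not part of this commit): hedged usage example for
# split_dataset(). The item layout [text, wav_file, speaker_name] is assumed
# from how the dataset formatters build items:
#
#   items = [["hello there", "wavs/a.wav", "spk1"],
#            ["general kenobi", "wavs/b.wav", "spk2"]] * 100
#   n = len(items)  # 200 -> eval_split_size = min(500, 2) = 2
#   eval_items, train_items = split_dataset(items)
#   assert len(eval_items) == min(500, int(n * 0.01))
#
# The split is deterministic (np.random.seed(0)) and, for multi-speaker data,
# a speaker's last remaining sample is never moved to the eval set.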

def load_meta_data(datasets, eval_split=True):
    meta_data_train_all = []
    meta_data_eval_all = [] if eval_split else None
    for dataset in datasets:
        name = dataset["name"]
        root_path = dataset["path"]
        meta_file_train = dataset["meta_file_train"]
        meta_file_val = dataset["meta_file_val"]
        # setup the right data processor
        preprocessor = get_preprocessor_by_name(name)
        # load train set
        meta_data_train = preprocessor(root_path, meta_file_train)
        print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
        # load evaluation split if set
        if eval_split:
            if meta_file_val:
                meta_data_eval = preprocessor(root_path, meta_file_val)
            else:
                meta_data_eval, meta_data_train = split_dataset(meta_data_train)
            meta_data_eval_all += meta_data_eval
        meta_data_train_all += meta_data_train
        # load attention masks for duration predictor training
        if dataset.meta_file_attn_mask:
            meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
            for idx, ins in enumerate(meta_data_train_all):
                attn_file = meta_data[ins[1]].strip()
                meta_data_train_all[idx].append(attn_file)
            if meta_data_eval_all:
                for idx, ins in enumerate(meta_data_eval_all):
                    attn_file = meta_data[ins[1]].strip()
                    meta_data_eval_all[idx].append(attn_file)
    return meta_data_train_all, meta_data_eval_all
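
# NOTE (editor's sketch, not part of this commit): hedged example of the
# expected `datasets` entries. The keys come from the lookups above; the
# values are hypothetical. The mixed access pattern above (dataset["name"]
# vs. dataset.meta_file_attn_mask) suggests a Coqpit-style config object that
# supports both dict- and attribute-style access:
#
#   datasets = [{
#       "name": "ljspeech",                 # resolved by get_preprocessor_by_name()
#       "path": "/data/LJSpeech-1.1",
#       "meta_file_train": "metadata.csv",
#       "meta_file_val": None,              # falls back to split_dataset()
#       "meta_file_attn_mask": None,        # optional, for duration predictor training
#   }]
#   train_items, eval_items = load_meta_data(datasets, eval_split=True)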

def load_attention_mask_meta_data(metafile_path):
    """Load meta data file created by compute_attention_masks.py"""
    with open(metafile_path, "r") as f:
        lines = f.readlines()

    meta_data = []
    for line in lines:
        wav_file, attn_file = line.split("|")
        meta_data.append([wav_file, attn_file])
    return meta_data
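
# NOTE (editor's sketch): the metafile parsed above is expected to hold one
# "wav_file|attn_file" pair per line, e.g. (hypothetical paths):
#
#   /data/wavs/LJ001-0001.wav|/data/attn_masks/LJ001-0001.npy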

def get_preprocessor_by_name(name):
    """Returns the respective preprocessing function."""
    thismodule = sys.modules[__name__]
    return getattr(thismodule, name.lower())
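
# NOTE (editor's sketch): get_preprocessor_by_name() resolves formatters
# dynamically. getattr() looks the lowercased dataset name up on this module,
# so
#
#   preprocessor = get_preprocessor_by_name("LJSpeech")
#
# returns the module-level function `ljspeech` defined in the DATASETS section
# below. Supporting a new dataset means adding one such function.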

########################
# DATASETS
########################
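
# NOTE (editor's sketch): the formatter definitions are truncated in this
# view. As a hedged illustration (not necessarily this commit's exact code),
# a formatter takes (root_path, meta_file) and returns
# [text, wav_file, speaker_name] items:
#
#   def ljspeech(root_path, meta_file):
#       items = []
#       with open(os.path.join(root_path, meta_file), "r", encoding="utf-8") as ttf:
#           for line in ttf:
#               cols = line.split("|")
#               wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
#               items.append([cols[1], wav_file, "ljspeech"])
#       return items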

View File

@@ -30,16 +30,16 @@ def init_arguments(argv):
     parser.add_argument(
         "--continue_path",
         type=str,
-        help=(
-            "Training output folder to continue training. Used to continue "
-            "a training. If it is used, 'config_path' is ignored."
-        ),
+        help=("Training output folder to continue training. Used to continue "
+              "a training. If it is used, 'config_path' is ignored."),
         default="",
         required="--config_path" not in argv,
     )
     parser.add_argument(
-        "--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default=""
-    )
+        "--restore_path",
+        type=str,
+        help="Model file to be restored. Use to finetune a model.",
+        default="")
     parser.add_argument(
         "--best_path",
         type=str,
@@ -49,12 +49,23 @@ def init_arguments(argv):
         ),
         default="",
     )
-    parser.add_argument(
-        "--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in argv
-    )
-    parser.add_argument("--debug", type=bool, default=False, help="Do not verify commit integrity to run training.")
-    parser.add_argument("--rank", type=int, default=0, help="DISTRIBUTED: process rank for distributed training.")
-    parser.add_argument("--group_id", type=str, default="", help="DISTRIBUTED: process group id.")
+    parser.add_argument("--config_path",
+                        type=str,
+                        help="Path to config file for training.",
+                        required="--continue_path" not in argv)
+    parser.add_argument("--debug",
+                        type=bool,
+                        default=False,
+                        help="Do not verify commit integrity to run training.")
+    parser.add_argument("--rank",
+                        type=int,
+                        default=0,
+                        help="DISTRIBUTED: process rank for distributed training.")
+    parser.add_argument("--group_id",
+                        type=str,
+                        default="",
+                        help="DISTRIBUTED: process group id.")
     return parser
@@ -149,7 +160,8 @@ def process_args(args):
         print(" > Mixed precision mode is ON")
     experiment_path = args.continue_path
     if not experiment_path:
-        experiment_path = create_experiment_folder(config.output_path, config.run_name, args.debug)
+        experiment_path = create_experiment_folder(config.output_path,
+                                                   config.run_name, args.debug)
     audio_path = os.path.join(experiment_path, "test_audios")
     # setup rank 0 process in distributed training
     tb_logger = None
@@ -170,7 +182,8 @@
             os.chmod(experiment_path, 0o775)
         tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
         # write model desc to tensorboard
-        tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
+        tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>",
+                              0)
     c_logger = ConsoleLogger()
     return config, experiment_path, audio_path, c_logger, tb_logger
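
The hunks above reformat the parser without changing its mutual-dependency trick: `--config_path` is required only when `--continue_path` is absent from the raw argv, and vice versa, because `required=` is evaluated against `argv` before parsing. A minimal, self-contained sketch of the pattern (hypothetical file, not the repo's exact code):

import argparse
import sys

def init_arguments(argv):
    parser = argparse.ArgumentParser()
    # each flag is required only when the other is absent from the raw argv
    parser.add_argument("--continue_path", type=str, default="",
                        required="--config_path" not in argv)
    parser.add_argument("--config_path", type=str,
                        required="--continue_path" not in argv)
    return parser

args = init_arguments(sys.argv).parse_args()
print(args.config_path or args.continue_path)

One caveat carried over unchanged by the reformat: `--debug` uses `type=bool`, and since `bool("False")` is `True` in Python, any non-empty value on the command line enables debug mode.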