mirror of https://github.com/coqui-ai/TTS.git
rename preprocess.py -> formatters.py
This commit is contained in:
parent b9bccbb243
commit a20a1c7d06
preprocess.py → formatters.py

@@ -1,93 +1,12 @@
 import os
 import re
-import sys
 import xml.etree.ElementTree as ET
-from collections import Counter
 from glob import glob
 from pathlib import Path
 from typing import List
 
-import numpy as np
 from tqdm import tqdm
 
-####################
-# UTILITIES
-####################
-
-
-def split_dataset(items):
-    speakers = [item[-1] for item in items]
-    is_multi_speaker = len(set(speakers)) > 1
-    eval_split_size = min(500, int(len(items) * 0.01))
-    assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples."
-    np.random.seed(0)
-    np.random.shuffle(items)
-    if is_multi_speaker:
-        items_eval = []
-        speakers = [item[-1] for item in items]
-        speaker_counter = Counter(speakers)
-        while len(items_eval) < eval_split_size:
-            item_idx = np.random.randint(0, len(items))
-            speaker_to_be_removed = items[item_idx][-1]
-            if speaker_counter[speaker_to_be_removed] > 1:
-                items_eval.append(items[item_idx])
-                speaker_counter[speaker_to_be_removed] -= 1
-                del items[item_idx]
-        return items_eval, items
-    return items[:eval_split_size], items[eval_split_size:]
-
-
-def load_meta_data(datasets, eval_split=True):
-    meta_data_train_all = []
-    meta_data_eval_all = [] if eval_split else None
-    for dataset in datasets:
-        name = dataset["name"]
-        root_path = dataset["path"]
-        meta_file_train = dataset["meta_file_train"]
-        meta_file_val = dataset["meta_file_val"]
-        # setup the right data processor
-        preprocessor = get_preprocessor_by_name(name)
-        # load train set
-        meta_data_train = preprocessor(root_path, meta_file_train)
-        print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
-        # load evaluation split if set
-        if eval_split:
-            if meta_file_val:
-                meta_data_eval = preprocessor(root_path, meta_file_val)
-            else:
-                meta_data_eval, meta_data_train = split_dataset(meta_data_train)
-            meta_data_eval_all += meta_data_eval
-        meta_data_train_all += meta_data_train
-        # load attention masks for duration predictor training
-        if dataset.meta_file_attn_mask:
-            meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
-            for idx, ins in enumerate(meta_data_train_all):
-                attn_file = meta_data[ins[1]].strip()
-                meta_data_train_all[idx].append(attn_file)
-            if meta_data_eval_all:
-                for idx, ins in enumerate(meta_data_eval_all):
-                    attn_file = meta_data[ins[1]].strip()
-                    meta_data_eval_all[idx].append(attn_file)
-    return meta_data_train_all, meta_data_eval_all
-
-
-def load_attention_mask_meta_data(metafile_path):
-    """Load meta data file created by compute_attention_masks.py"""
-    with open(metafile_path, "r") as f:
-        lines = f.readlines()
-
-    meta_data = []
-    for line in lines:
-        wav_file, attn_file = line.split("|")
-        meta_data.append([wav_file, attn_file])
-    return meta_data
-
-
-def get_preprocessor_by_name(name):
-    """Returns the respective preprocessing function."""
-    thismodule = sys.modules[__name__]
-    return getattr(thismodule, name.lower())
-
-
 ########################
 # DATASETS
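The block removed above is the generic utility layer that preprocess.py carried next to the dataset formatters: split_dataset carves an evaluation split of min(500, 1% of the items) off the training items (hence the assert demanding at least 100 samples), load_meta_data walks the configured datasets and builds the train/eval metadata lists, and get_preprocessor_by_name maps the dataset "name" field from the config onto a formatter function defined in the same module. A minimal, self-contained sketch of that name-based dispatch is shown below; the ljspeech body is illustrative only, not the library's actual parser.

import sys


def ljspeech(root_path, meta_file):
    """Illustrative formatter only (not the library's actual parser):
    one [text, wav_path, speaker_name] item per metadata row."""
    items = []
    with open(f"{root_path}/{meta_file}", encoding="utf-8") as f:
        for line in f:
            cols = line.strip().split("|")
            wav_file = f"{root_path}/wavs/{cols[0]}.wav"
            items.append([cols[-1], wav_file, "ljspeech"])
    return items


def get_preprocessor_by_name(name):
    """Same lookup as the removed helper: the dataset 'name' in the config
    must match a function defined in this module, compared in lowercase."""
    return getattr(sys.modules[__name__], name.lower())


formatter = get_preprocessor_by_name("LJSpeech")
# items = formatter("/data/LJSpeech-1.1", "metadata.csv")

The hunks that follow appear to come from a second file changed in the same commit, the training-script argument handling (init_arguments and process_args).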
@@ -30,16 +30,16 @@ def init_arguments(argv):
     parser.add_argument(
         "--continue_path",
         type=str,
-        help=(
-            "Training output folder to continue training. Used to continue "
-            "a training. If it is used, 'config_path' is ignored."
-        ),
+        help=("Training output folder to continue training. Used to continue "
+              "a training. If it is used, 'config_path' is ignored."),
         default="",
         required="--config_path" not in argv,
     )
     parser.add_argument(
-        "--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default=""
-    )
+        "--restore_path",
+        type=str,
+        help="Model file to be restored. Use to finetune a model.",
+        default="")
     parser.add_argument(
         "--best_path",
         type=str,
@@ -49,12 +49,23 @@ def init_arguments(argv):
         ),
         default="",
     )
+    parser.add_argument("--config_path",
+                        type=str,
+                        help="Path to config file for training.",
+                        required="--continue_path" not in argv)
+    parser.add_argument("--debug",
+                        type=bool,
+                        default=False,
+                        help="Do not verify commit integrity to run training.")
     parser.add_argument(
-        "--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in argv
-    )
-    parser.add_argument("--debug", type=bool, default=False, help="Do not verify commit integrity to run training.")
-    parser.add_argument("--rank", type=int, default=0, help="DISTRIBUTED: process rank for distributed training.")
-    parser.add_argument("--group_id", type=str, default="", help="DISTRIBUTED: process group id.")
+        "--rank",
+        type=int,
+        default=0,
+        help="DISTRIBUTED: process rank for distributed training.")
+    parser.add_argument("--group_id",
+                        type=str,
+                        default="",
+                        help="DISTRIBUTED: process group id.")
 
     return parser
 
@@ -149,7 +160,8 @@ def process_args(args):
         print(" > Mixed precision mode is ON")
     experiment_path = args.continue_path
     if not experiment_path:
-        experiment_path = create_experiment_folder(config.output_path, config.run_name, args.debug)
+        experiment_path = create_experiment_folder(config.output_path,
+                                                   config.run_name, args.debug)
     audio_path = os.path.join(experiment_path, "test_audios")
     # setup rank 0 process in distributed training
     tb_logger = None
@@ -170,7 +182,8 @@ def process_args(args):
         os.chmod(experiment_path, 0o775)
         tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
         # write model desc to tensorboard
-        tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
+        tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>",
+                              0)
     c_logger = ConsoleLogger()
     return config, experiment_path, audio_path, c_logger, tb_logger
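The two process_args hunks only re-wrap long calls. Their context lines show the pattern they sit in: tb_logger starts as None and, judging by the "# setup rank 0 process in distributed training" comment, the TensorBoard logger appears to be created only on the rank-0 process. A self-contained sketch of that guard, with DummyTbLogger as a hypothetical stand-in for the library's TensorboardLogger and an illustrative model name:

class DummyTbLogger:
    """Hypothetical stand-in for the TensorboardLogger referenced above."""

    def __init__(self, log_dir, model_name):
        self.log_dir = log_dir
        self.model_name = model_name

    def tb_add_text(self, tag, text, step):
        print(f"[{self.model_name}] step {step} | {tag}: {text}")


def setup_tb_logger(experiment_path, model_name, rank):
    # non-zero ranks in distributed training skip TensorBoard logging entirely
    tb_logger = None
    if rank == 0:
        tb_logger = DummyTbLogger(experiment_path, model_name=model_name)
        # mirrors the tb_add_text call re-wrapped in the last hunk
        tb_logger.tb_add_text("model-config", "<pre>{ ... }</pre>", 0)
    return tb_logger


setup_tb_logger("/tmp/experiment", "tacotron2", rank=0)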