Enable custom formatter in load_tts_samples

Eren Gölge 2021-10-26 13:06:22 +02:00
parent 7c10574931
commit 0cac3f330a
1 changed file with 21 additions and 7 deletions


@@ -1,7 +1,7 @@
 import sys
 from collections import Counter
 from pathlib import Path
-from typing import Dict, List, Tuple, Union
+from typing import Callable, Dict, List, Tuple, Union
 
 import numpy as np
@@ -10,6 +10,11 @@ from TTS.tts.datasets.formatters import *
 def split_dataset(items):
+    """Split a dataset into train and eval. Consider speaker distribution in multi-speaker training.
+
+    Args:
+        items (List[List]): A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`.
+    """
     speakers = [item[-1] for item in items]
     is_multi_speaker = len(set(speakers)) > 1
     eval_split_size = min(500, int(len(items) * 0.01))
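The docstring added above pins down the sample layout `split_dataset` works with. A minimal sketch of calling it (paths, transcripts, and speaker ids are made up; the import assumes this file is `TTS/tts/datasets/__init__.py`, as the hunk context suggests):

```python
from TTS.tts.datasets import split_dataset

# Each sample is [audio_path, text, speaker_id]; the speaker id sits in the last
# slot, which is what the multi-speaker handling in split_dataset keys on.
items = [[f"wavs/{i:04d}.wav", "a transcript", f"speaker_{i % 2}"] for i in range(2000)]

# The eval split is ~1% of the samples, capped at 500 (20 items here); the rest
# stays in the train split.
eval_items, train_items = split_dataset(items)
```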
@@ -31,15 +36,23 @@ def split_dataset(items):
     return items[:eval_split_size], items[eval_split_size:]
 
 
-def load_tts_samples(datasets: Union[List[Dict], Dict], eval_split=True) -> Tuple[List[List], List[List]]:
-    """Parse the dataset, load the samples as a list and load the attention alignments if provided.
+def load_tts_samples(datasets: Union[List[Dict], Dict], eval_split=True, formatter: Callable = None) -> Tuple[List[List], List[List]]:
+    """Parse the dataset from the datasets config, load the samples as a list and load the attention alignments if provided.
+    If `formatter` is not None, apply the formatter to the samples, else pick the formatter from the available ones based
+    on the dataset name.
 
     Args:
         datasets (List[Dict], Dict): A list of datasets or a single dataset dictionary. If multiple datasets are
             in the list, they are all merged.
 
         eval_split (bool, optional): If true, create an evaluation split. If no eval split is provided explicitly,
            generate one automatically. Defaults to True.
 
+        formatter (Callable, optional): The preprocessing function to be applied to create the list of samples. It
+            must take the root_path and the meta_file name and return a list of samples in the format of
+            `[[audio_path, text, speaker_id], ...]`. See the available formatters in `TTS.tts.datasets.formatters` as
+            examples. Defaults to None.
+
     Returns:
         Tuple[List[List], List[List]]: training and evaluation splits of the dataset.
     """
@@ -53,14 +66,15 @@ def load_tts_samples(datasets: Union[List[Dict], Dict], eval_split=True) -> Tupl
         meta_file_train = dataset["meta_file_train"]
         meta_file_val = dataset["meta_file_val"]
         # setup the right data processor
-        preprocessor = _get_preprocessor_by_name(name)
+        if formatter is None:
+            formatter = _get_formatter_by_name(name)
         # load train set
-        meta_data_train = preprocessor(root_path, meta_file_train)
+        meta_data_train = formatter(root_path, meta_file_train)
         print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
         # load evaluation split if set
         if eval_split:
             if meta_file_val:
-                meta_data_eval = preprocessor(root_path, meta_file_val)
+                meta_data_eval = formatter(root_path, meta_file_val)
             else:
                 meta_data_eval, meta_data_train = split_dataset(meta_data_train)
             meta_data_eval_all += meta_data_eval
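For comparison, the fallback branch above keeps the old behaviour: with no `formatter` passed, the dataset `name` is still resolved to a built-in formatter, here `ljspeech`, one of the functions pulled in by the star import of `TTS.tts.datasets.formatters`. The path is a placeholder and `BaseDatasetConfig` is assumed as in the previous sketch.

```python
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples

# formatter is None, so "ljspeech" is resolved via _get_formatter_by_name(name).
dataset_config = BaseDatasetConfig(
    name="ljspeech", path="/data/LJSpeech-1.1/", meta_file_train="metadata.csv", meta_file_val=""
)
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
```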
@@ -90,7 +104,7 @@ def load_attention_mask_meta_data(metafile_path):
     return meta_data
 
 
-def _get_preprocessor_by_name(name):
+def _get_formatter_by_name(name):
     """Returns the respective preprocessing function."""
    thismodule = sys.modules[__name__]
    return getattr(thismodule, name.lower())
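The renamed helper keeps the reflection-based lookup: the lowercased dataset name is fetched as an attribute of this module, which the star import of `TTS.tts.datasets.formatters` populates with the built-in formatter functions. A self-contained sketch of that pattern (the names below are illustrative, not part of TTS):

```python
import sys


def greet(root_path, meta_file):
    """Stand-in for a formatter function defined at module level."""
    return [[f"{root_path}/a.wav", "hello", "spk_0"]]


def _get_formatter_by_name(name):
    """Look the formatter up as a module-level attribute, matching the lowercased name."""
    thismodule = sys.modules[__name__]
    return getattr(thismodule, name.lower())


formatter = _get_formatter_by_name("GREET")  # case-insensitive thanks to .lower()
assert formatter is greet
```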