diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index c163a11d..942a1365 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -1,7 +1,7 @@ import sys from collections import Counter from pathlib import Path -from typing import Dict, List, Tuple, Union +from typing import Callable, Dict, List, Tuple, Union import numpy as np @@ -10,6 +10,11 @@ from TTS.tts.datasets.formatters import * def split_dataset(items): + """Split a dataset into train and eval. Consider speaker distribution in multi-speaker training. + + Args: + items (List[List]): A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`. + """ speakers = [item[-1] for item in items] is_multi_speaker = len(set(speakers)) > 1 eval_split_size = min(500, int(len(items) * 0.01)) @@ -31,15 +36,23 @@ def split_dataset(items): return items[:eval_split_size], items[eval_split_size:] -def load_tts_samples(datasets: Union[List[Dict], Dict], eval_split=True) -> Tuple[List[List], List[List]]: - """Parse the dataset, load the samples as a list and load the attention alignments if provided. +def load_tts_samples(datasets: Union[List[Dict], Dict], eval_split=True, formatter: Callable=None) -> Tuple[List[List], List[List]]: + """Parse the dataset from the datasets config, load the samples as a List and load the attention alignments if provided. + If `formatter` is not None, apply the formatter to the samples else pick the formatter from the available ones based + on the dataset name. Args: datasets (List[Dict], Dict): A list of datasets or a single dataset dictionary. If multiple datasets are in the list, they are all merged. + eval_split (bool, optional): If true, create an evaluation split. If an eval split is not provided explicitly, one is generated automatically from the training data. Defaults to True. + formatter (Callable, optional): The preprocessing function to be applied to create the list of samples. 
It + must take the root_path and the meta_file name and return a list of samples in the format of + `[[audio_path, text, speaker_id], ...]`. See the available formatters in `TTS.tts.datasets.formatters` as + examples. Defaults to None. + + Returns: Tuple[List[List], List[List]: training and evaluation splits of the dataset. """ @@ -53,14 +66,15 @@ def load_tts_samples(datasets: Union[List[Dict], Dict], eval_split=True) -> Tupl meta_file_train = dataset["meta_file_train"] meta_file_val = dataset["meta_file_val"] # setup the right data processor - preprocessor = _get_preprocessor_by_name(name) + if formatter is None: + formatter = _get_formatter_by_name(name) # load train set - meta_data_train = preprocessor(root_path, meta_file_train) + meta_data_train = formatter(root_path, meta_file_train) print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}") # load evaluation split if set if eval_split: if meta_file_val: - meta_data_eval = preprocessor(root_path, meta_file_val) + meta_data_eval = formatter(root_path, meta_file_val) else: meta_data_eval, meta_data_train = split_dataset(meta_data_train) meta_data_eval_all += meta_data_eval @@ -90,7 +104,7 @@ def load_attention_mask_meta_data(metafile_path): return meta_data -def _get_preprocessor_by_name(name): +def _get_formatter_by_name(name): """Returns the respective preprocessing function.""" thismodule = sys.modules[__name__] return getattr(thismodule, name.lower())