Enable custom formatter in load_tts_samples

This commit is contained in:
Eren Gölge 2021-10-26 13:06:22 +02:00
parent 7c10574931
commit 0cac3f330a
1 changed file with 21 additions and 7 deletions

View File

@ -1,7 +1,7 @@
import sys
from collections import Counter
from pathlib import Path
from typing import Dict, List, Tuple, Union
from typing import Callable, Dict, List, Tuple, Union
import numpy as np
@ -10,6 +10,11 @@ from TTS.tts.datasets.formatters import *
def split_dataset(items):
"""Split a dataset into train and eval. Consider speaker distribution in multi-speaker training.
Args:
items (List[List]): A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`.
"""
speakers = [item[-1] for item in items]
is_multi_speaker = len(set(speakers)) > 1
eval_split_size = min(500, int(len(items) * 0.01))
@ -31,15 +36,23 @@ def split_dataset(items):
return items[:eval_split_size], items[eval_split_size:]
def load_tts_samples(datasets: Union[List[Dict], Dict], eval_split=True) -> Tuple[List[List], List[List]]:
"""Parse the dataset, load the samples as a list and load the attention alignments if provided.
def load_tts_samples(datasets: Union[List[Dict], Dict], eval_split=True, formatter: Callable=None) -> Tuple[List[List], List[List]]:
"""Parse the dataset from the datasets config, load the samples as a List and load the attention alignments if provided.
If `formatter` is not None, apply the formatter to the samples else pick the formatter from the available ones based
on the dataset name.
Args:
datasets (List[Dict], Dict): A list of datasets or a single dataset dictionary. If multiple datasets are
in the list, they are all merged.
eval_split (bool, optional): If true, create an evaluation split. If an eval split is not provided explicitly, generate
an eval split automatically. Defaults to True.
formatter (Callable, optional): The preprocessing function to be applied to create the list of samples. It
must take the root_path and the meta_file name and return a list of samples in the format of
`[[audio_path, text, speaker_id], ...]`. See the available formatters in `TTS.tts.datasets.formatters` as
example. Defaults to None.
Returns:
Tuple[List[List], List[List]]: training and evaluation splits of the dataset.
"""
@ -53,14 +66,15 @@ def load_tts_samples(datasets: Union[List[Dict], Dict], eval_split=True) -> Tupl
meta_file_train = dataset["meta_file_train"]
meta_file_val = dataset["meta_file_val"]
# setup the right data processor
preprocessor = _get_preprocessor_by_name(name)
if formatter is None:
formatter = _get_formatter_by_name(name)
# load train set
meta_data_train = preprocessor(root_path, meta_file_train)
meta_data_train = formatter(root_path, meta_file_train)
print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
# load evaluation split if set
if eval_split:
if meta_file_val:
meta_data_eval = preprocessor(root_path, meta_file_val)
meta_data_eval = formatter(root_path, meta_file_val)
else:
meta_data_eval, meta_data_train = split_dataset(meta_data_train)
meta_data_eval_all += meta_data_eval
@ -90,7 +104,7 @@ def load_attention_mask_meta_data(metafile_path):
return meta_data
def _get_preprocessor_by_name(name):
def _get_formatter_by_name(name):
"""Returns the respective preprocessing function."""
thismodule = sys.modules[__name__]
return getattr(thismodule, name.lower())