mirror of https://github.com/coqui-ai/TTS.git
Enable custom formatter in load_tts_samples
parent 7c10574931
commit 0cac3f330a
@@ -1,7 +1,7 @@
 import sys
 from collections import Counter
 from pathlib import Path
-from typing import Dict, List, Tuple, Union
+from typing import Callable, Dict, List, Tuple, Union

 import numpy as np

@@ -10,6 +10,11 @@ from TTS.tts.datasets.formatters import *


 def split_dataset(items):
+    """Split a dataset into train and eval. Consider speaker distribution in multi-speaker training.
+
+    Args:
+        items (List[List]): A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`.
+    """
     speakers = [item[-1] for item in items]
     is_multi_speaker = len(set(speakers)) > 1
     eval_split_size = min(500, int(len(items) * 0.01))
@@ -31,15 +36,23 @@ def split_dataset(items):
     return items[:eval_split_size], items[eval_split_size:]


-def load_tts_samples(datasets: Union[List[Dict], Dict], eval_split=True) -> Tuple[List[List], List[List]]:
-    """Parse the dataset, load the samples as a list and load the attention alignments if provided.
+def load_tts_samples(datasets: Union[List[Dict], Dict], eval_split=True, formatter: Callable = None) -> Tuple[List[List], List[List]]:
+    """Parse the dataset from the datasets config, load the samples as a List and load the attention alignments if provided.
+    If `formatter` is not None, apply the formatter to the samples; else pick the formatter from the available ones based
+    on the dataset name.

     Args:
         datasets (List[Dict], Dict): A list of datasets or a single dataset dictionary. If multiple datasets are
             in the list, they are all merged.

         eval_split (bool, optional): If true, create an evaluation split. If an eval split is not provided explicitly,
             one is generated automatically from the training set. Defaults to True.

+        formatter (Callable, optional): The preprocessing function to be applied to create the list of samples. It
+            must take the root_path and the meta_file name and return a list of samples in the format of
+            `[[audio_path, text, speaker_id], ...]`. See the available formatters in `TTS.tts.datasets.formatters` as
+            examples. Defaults to None.
+
     Returns:
         Tuple[List[List], List[List]]: training and evaluation splits of the dataset.
     """
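To make the new contract concrete, a user-defined formatter compatible with the docstring above might look like the sketch below. The `my_csv_formatter` name, the pipe-separated metadata layout, and the `wavs/` directory are illustrative assumptions, not part of this commit; the only requirement stated above is that the callable takes `root_path` and the meta file name and returns `[[audio_path, text, speaker_id], ...]`.

import os
from typing import List


def my_csv_formatter(root_path: str, meta_file: str) -> List[List[str]]:
    """Hypothetical formatter: parse `<wav_name>|<transcript>` lines into samples."""
    samples = []
    with open(os.path.join(root_path, meta_file), "r", encoding="utf-8") as f:
        for line in f:
            wav_name, text = line.strip().split("|", maxsplit=1)
            # assumed dataset layout: audio clips live under <root_path>/wavs/
            audio_path = os.path.join(root_path, "wavs", wav_name + ".wav")
            # speaker id goes last, matching the `item[-1]` access in split_dataset
            samples.append([audio_path, text, "single_speaker"])
    return samples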
@@ -53,14 +66,15 @@ def load_tts_samples(datasets: Union[List[Dict], Dict], eval_split=True) -> Tupl
         meta_file_train = dataset["meta_file_train"]
         meta_file_val = dataset["meta_file_val"]
         # setup the right data processor
-        preprocessor = _get_preprocessor_by_name(name)
+        if formatter is None:
+            formatter = _get_formatter_by_name(name)
         # load train set
-        meta_data_train = preprocessor(root_path, meta_file_train)
+        meta_data_train = formatter(root_path, meta_file_train)
         print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
         # load evaluation split if set
         if eval_split:
             if meta_file_val:
-                meta_data_eval = preprocessor(root_path, meta_file_val)
+                meta_data_eval = formatter(root_path, meta_file_val)
             else:
                 meta_data_eval, meta_data_train = split_dataset(meta_data_train)
             meta_data_eval_all += meta_data_eval
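With the wiring above, a caller can pass a custom formatter straight to `load_tts_samples`, bypassing the name-based lookup. The sketch below assumes the dataset dict also carries `name` and `path` keys alongside the `meta_file_train`/`meta_file_val` keys read in this hunk, and that `load_tts_samples` is importable from `TTS.tts.datasets`; both are assumptions about surrounding code not shown in the diff.

from TTS.tts.datasets import load_tts_samples  # assumed import path

dataset_config = {
    "name": "my_csv",                 # not used for formatter lookup when `formatter` is given
    "path": "/data/my_dataset",       # becomes `root_path`
    "meta_file_train": "metadata_train.csv",
    "meta_file_val": "metadata_val.csv",
}

# `formatter` overrides the `_get_formatter_by_name(name)` fallback
train_samples, eval_samples = load_tts_samples(
    dataset_config, eval_split=True, formatter=my_csv_formatter
)
print(f"{len(train_samples)} train / {len(eval_samples)} eval samples")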
@@ -90,7 +104,7 @@ def load_attention_mask_meta_data(metafile_path):
     return meta_data


-def _get_preprocessor_by_name(name):
+def _get_formatter_by_name(name):
     """Returns the respective preprocessing function."""
     thismodule = sys.modules[__name__]
     return getattr(thismodule, name.lower())
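The renamed `_get_formatter_by_name` keeps the old behaviour: it resolves a dataset name to a same-named, lower-cased module-level function, which is how the built-in formatters pulled in by `from TTS.tts.datasets.formatters import *` are picked up when no custom formatter is passed. A minimal, self-contained sketch of that lookup, with a stub standing in for a real formatter:

import sys


def ljspeech(root_path, meta_file):
    """Stub standing in for the real LJSpeech formatter."""
    return []


def _get_formatter_by_name(name):
    """Returns the respective preprocessing function."""
    thismodule = sys.modules[__name__]
    return getattr(thismodule, name.lower())


print(_get_formatter_by_name("LJSpeech") is ljspeech)  # True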