mirror of https://github.com/coqui-ai/TTS.git
Implement tarring datasets
This commit is contained in:
parent
42f77e7185
commit
2447f42ca1
@@ -0,0 +1,134 @@
import json
import os
import random
import tarfile
from dataclasses import dataclass, field
from datetime import datetime
from multiprocessing import Pool
from typing import List

from coqpit import Coqpit

from TTS.config import BaseDatasetConfig
from TTS.stt.datasets import load_stt_samples


@dataclass
class ConvertedArgs(Coqpit):
    dataset_name: List[str] = field(
        default_factory=list,
        metadata={
            "help": "Name of the dataset(s) or the dataset format(s). Provided name(s) must be implemented in `stt.datasets.formatters`."
        },
    )
    dataset_path: List[str] = field(default_factory=list, metadata={"help": "Path(s) to the dataset(s)."})
    output_path: str = field(default="", metadata={"help": "Path to the output directory to save the tar shards."})
    num_shards: int = field(default=-1, metadata={"help": "Number of tarballs to create."})
    shuffle: bool = field(default=False, metadata={"help": "Shuffle the samples before tarring."})
    num_workers: int = field(default=1, metadata={"help": "Number of workers to use for parallelization."})


@dataclass
class TarMetadata(Coqpit):
    args: ConvertedArgs = field(default_factory=ConvertedArgs)
    created_date: str = field(default="", metadata={"help": "Date of creation of the tarball dataset."})
    num_samples_per_shard: int = field(default=0, metadata={"help": "Number of samples per tarball."})

    def __post_init__(self):
        self.created_date = self.get_date()

    @staticmethod
    def get_date():
        return datetime.now().strftime("%m-%d-%Y %H-%M-%S")


def create_tar_shard(params):
    # Each worker receives (samples, output_path, shard_no) as a single tuple
    # so the function can be used with Pool.map().
    samples, output_path, shard_no = params

    sharded_samples = []
    with tarfile.open(os.path.join(output_path, f"audio_{shard_no}.tar"), mode="w") as tar:
        count = {}
        for sample in samples:
            # We squash the filename since we do not preserve the directory structure
            # of audio files in the tarball.
            base, ext = os.path.splitext(sample["audio_file"])
            base = base.replace("/", "_")
            # Need the following replacement as long as WebDataset splits on the first period.
            base = base.replace(".", "_")
            squashed_filename = f"{base}{ext}"
            if squashed_filename not in count:
                tar.add(sample["audio_file"], arcname=squashed_filename)

            if "duration" in sample:
                duration = sample["duration"]
            else:
                # No duration in the sample; fall back to the file size in bytes as a
                # rough proxy. Note this is not a duration in seconds.
                duration = os.path.getsize(sample["audio_file"])
            count[squashed_filename] = 1
            sharded_sample = {
                "audio_file": squashed_filename,
                "duration": duration,
                "text": sample["text"],
                "shard_no": shard_no,  # Keep shard ID for recordkeeping
            }
            sharded_samples.append(sharded_sample)
    return sharded_samples


if __name__ == "__main__":
    # parse command line arguments
    args = ConvertedArgs()
    args.parse_args(arg_prefix="")
    os.makedirs(args.output_path, exist_ok=True)

    # create tarring metadata config
    metadata_config = TarMetadata(args=args)

    # create dataset configs
    dataset_configs = []
    for dataset_name, dataset_path in zip(args.dataset_name, args.dataset_path):
        dataset_config = BaseDatasetConfig(name=dataset_name, path=dataset_path)
        dataset_configs.append(dataset_config)

    # load dataset samples
    samples, _ = load_stt_samples(dataset_configs, eval_split=False)
    print(f" > Number of data samples: {len(samples)}")

    # shuffle samples
    if args.shuffle:
        print(" > Shuffling data samples...")
        random.shuffle(samples)

    # define shard sample indices
    start_indices = []
    end_indices = []
    shard_size = len(samples) // args.num_shards
    for i in range(args.num_shards):
        start_idx = shard_size * i
        end_idx = start_idx + shard_size
        print(f" > Shard {i}: {start_idx} --> {end_idx}")
        start_indices.append(start_idx)
        end_indices.append(end_idx)
    # Integer division above means up to `num_shards - 1` trailing samples do not
    # fit into a full shard; they are discarded to keep shard sizes equal.
    num_leftover = len(samples) - shard_size * args.num_shards
    if num_leftover > 0:
        print(f" > Have {num_leftover} entries left over that will be discarded.")

    # create shards
    with Pool(args.num_workers) as pool:
        process_samples = [samples[start_idx:end_idx] for start_idx, end_idx in zip(start_indices, end_indices)]
        process_args = zip(process_samples, [args.output_path] * args.num_shards, range(args.num_shards))
        sharded_samples = pool.map(create_tar_shard, process_args)
    # flatten the per-shard lists of sample dicts
    sharded_samples = [sample for shard in sharded_samples for sample in shard]
    print(f" > Total number of files sharded: {len(sharded_samples)}")

    # Write the manifest: one JSON object per line (JSONL)
    metadata_path = os.path.join(args.output_path, "coqui_tarred_dataset.json")
    with open(metadata_path, "w", encoding="utf8") as manifest_file:
        for entry in sharded_samples:
            json.dump(entry, manifest_file)
            manifest_file.write("\n")

    # Write metadata (default metadata for new datasets)
    metadata_config.num_samples_per_shard = shard_size
    metadata_path = os.path.join(args.output_path, "metadata.json")
    metadata_config.save_json(metadata_path)
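For context, a sketch of how the script above might be driven. The flag names are inferred from the `ConvertedArgs` field names via Coqpit's CLI parsing, and the script path is hypothetical, so treat this as an assumption rather than a documented CLI:

# Hypothetical command line (flag names follow the Coqpit field names above):
#   python convert_to_tarred_dataset.py \
#       --dataset_name ljspeech --dataset_path /data/LJSpeech-1.1 \
#       --output_path /data/tarred --num_shards 64 --shuffle true --num_workers 8

# Programmatic equivalent, constructing the dataclass directly:
args = ConvertedArgs(
    dataset_name=["ljspeech"],
    dataset_path=["/data/LJSpeech-1.1"],
    output_path="/data/tarred",
    num_shards=64,
    shuffle=True,
    num_workers=8,
)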
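The filename squashing in `create_tar_shard` only matters if the shards are later read with WebDataset, which keys each sample on everything before the first period. A minimal read-back sketch, assuming the third-party `webdataset` package, 64 shards, and FLAC audio (all assumptions, not part of this commit):

import json

import webdataset as wds  # assumed third-party dependency

# The JSONL manifest written above: one sample per line.
with open("/data/tarred/coqui_tarred_dataset.json", encoding="utf8") as f:
    manifest = [json.loads(line) for line in f]

# Brace expansion enumerates the shards produced by create_tar_shard().
dataset = wds.WebDataset("/data/tarred/audio_{0..63}.tar")
for sample in dataset:
    key = sample["__key__"]       # squashed filename without its extension
    audio_bytes = sample["flac"]  # raw bytes, keyed by the audio file extension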
@@ -1,12 +1,12 @@
-from tqdm import tqdm
 import glob
 import os
 from multiprocessing import Pool
-
-from TTS.stt.utils.download import download_url, extract_archive
 from TTS.stt.datasets.formatters import *
+import librosa
+import soundfile as sf
+from tqdm import tqdm
+
+from TTS.stt.utils.download import download_url, extract_archive
 
 
 def _resample_file(func_args):
@@ -68,4 +68,4 @@ def download_librispeech(path: str, split_name: str):
 
 if __name__ == "__main__":
     # download_librispeech("/home/ubuntu/librispeech/", "train-clean-100")
-    download_ljspeech("/home/ubuntu/ljspeech/", n_jobs=8)
+    # download_ljspeech("/home/ubuntu/ljspeech/", n_jobs=8)
@@ -32,7 +32,7 @@ def librispeech(root_path, meta_files=None):
     _delimiter = " "
     _audio_ext = ".flac"
     items = []
-    if meta_files is None:
+    if meta_files is None or meta_files == "":
         meta_files = glob(f"{root_path}/**/*trans.txt", recursive=True)
     else:
         if isinstance(meta_files, str):
@@ -52,7 +52,7 @@ audio_config = BaseAudioConfig(
 config = DeepSpeechConfig(
     audio=audio_config,
     run_name="deepspeech_librispeech",
-    batch_size=128,
+    batch_size=64,
     eval_batch_size=16,
     batch_group_size=5,
     num_loader_workers=4,