Implement STTDataset

This commit is contained in:
Eren Gölge 2021-09-30 14:41:22 +00:00
parent 0c7a2eb948
commit 1787b2303a
2 changed files with 339 additions and 0 deletions

281
TTS/stt/datasets/dataset.py Normal file
View File

@@ -0,0 +1,281 @@
import collections
import os
import random
from typing import List, Union
import numpy as np
import torch
import torchaudio
from torch.utils.data import Dataset
from TTS.stt.datasets.tokenizer import Tokenizer
from TTS.tts.utils.data import prepare_data, prepare_tensor
from TTS.utils.audio import AudioProcessor
class FeatureExtractor:
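    """Compute acoustic features from a raw waveform with an ``AudioProcessor``.

    ``feature_name`` selects the feature type: "MFCC" (mean/variance normalized), "MEL" or "LINEAR_SPEC".
    """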
_FEATURE_NAMES = ["MFCC", "MEL", "LINEAR_SPEC"]
def __init__(self, feature_name: str, ap: AudioProcessor) -> None:
self.feature_name = feature_name
self.ap = ap
def __call__(self, waveform: np.ndarray):
if self.feature_name == "MFCC":
mfcc = self.ap.mfcc(waveform)
return (mfcc - mfcc.mean()) / (mfcc.std() + 1e-6)
elif self.feature_name == "MEL":
return self.ap.melspectrogram(waveform)
elif self.feature_name == "LINEAR_SPEC":
return self.ap.spectrogram(waveform)
else:
raise ValueError("[!] Unknown feature_name: {}".format(self.feature_name))
class STTDataset(Dataset):
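    """Generic STT dataset that pairs audio files with tokenized transcripts.

    Samples are length-sorted, filtered and optionally bucket-shuffled at initialization. Acoustic
    features are computed on the fly in ``collate_fn``, so pass it as the ``collate_fn`` of a
    ``DataLoader``.
    """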
def __init__(
self,
samples: List[List],
ap: AudioProcessor,
tokenizer: Tokenizer,
batch_group_size: int = 0,
sort_by_audio_len: bool = True,
min_seq_len: int = 0,
max_seq_len: int = float("inf"),
cache_path: str = None,
verbose: bool = False,
feature_extractor: Union[FeatureExtractor, str] = "MFCC",
):
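        """
        Args:
            samples (List[List]): Samples produced by the dataset formatters. Each sample defines
                ``text``, ``audio_file`` and ``speaker_name``.
            ap (AudioProcessor): Audio processor used to load waveforms and compute features.
            tokenizer (Tokenizer): Tokenizer used to encode the transcripts.
            batch_group_size (int): Size of the length-sorted groups that are shuffled to add length
                variety to the batches. Defaults to 0 (disabled).
            sort_by_audio_len (bool): If True, sort samples by audio length, else by text length. Defaults to True.
            min_seq_len (int): Samples shorter than this are discarded. Defaults to 0.
            max_seq_len (int): Samples longer than this are discarded. Defaults to `float("inf")`.
            cache_path (str): Path to a cache folder, created if it does not exist. Defaults to None.
            verbose (bool): If True, print dataset statistics at initialization. Defaults to False.
            feature_extractor (Union[FeatureExtractor, str]): Feature extractor instance or the name of the
                feature to compute ("MFCC", "MEL" or "LINEAR_SPEC"). Defaults to "MFCC".
        """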
super().__init__()
self.samples = samples
self.ap = ap
self.tokenizer = tokenizer
self.batch_group_size = batch_group_size
self.sample_rate = ap.sample_rate
self.min_seq_len = min_seq_len
self.max_seq_len = max_seq_len
self.cache_path = cache_path
self.verbose = verbose
self.rescue_item_idx = 1
if isinstance(feature_extractor, str):
self.feature_extractor = FeatureExtractor(feature_extractor, ap)
else:
self.feature_extractor = feature_extractor
if cache_path:
os.makedirs(cache_path, exist_ok=True)
if self.verbose:
print("\n > STT DataLoader initialization")
print(" | > Number of instances : {}".format(len(self.samples)))
self.sort_filter_bucket_samples(sort_by_audio_len)
def load_wav(self, filename):
audio = self.ap.load_wav(filename)
return audio
@staticmethod
def load_np(filename):
data = np.load(filename).astype("float32")
return data
def load_data(self, idx):
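        """Load the sample at ``idx``: read the waveform, encode the transcript and return both as a dict."""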
sample = self.samples[idx]
waveform = np.asarray(self.load_wav(sample["audio_file"]), dtype=np.float32)
# waveform, sr = torchaudio.load(sample["audio_file"])
# waveform = torchaudio.functional.resample(waveform, sr, self.ap.sample_rate)[0].numpy()
if len(sample["text"]) > self.max_seq_len:
            # return a fallback sample if the transcript is longer than the threshold
            # TODO: find a better fix
return self.load_data(self.rescue_item_idx)
tokens = self.tokenizer.encode(sample["text"])
return_sample = {
"tokens": np.array(tokens),
"token_length": len(tokens),
"text": sample["text"],
"waveform": waveform,
"audio_file": sample["audio_file"],
"speaker_name": sample["speaker_name"],
}
return return_sample
def sort_filter_bucket_samples(self, by_audio_len=False):
r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
range.
Args:
by_audio_len (bool): if True, sort by audio length else by text length.
"""
# compute the target sequence length
if by_audio_len:
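            # use the file size on disk as a cheap proxy for the audio duration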
lengths = []
for sample in self.samples:
lengths.append(os.path.getsize(sample["audio_file"]))
lengths = np.array(lengths)
else:
            lengths = np.array([len(smp["text"]) for smp in self.samples])
# sort items based on the sequence length in ascending order
idxs = np.argsort(lengths)
sorted_samples = []
ignored = []
for i, idx in enumerate(idxs):
length = lengths[idx]
if length < self.min_seq_len or length > self.max_seq_len:
ignored.append(idx)
else:
sorted_samples.append(self.samples[idx])
# shuffle batch groups
# create batches with similar length items
# the larger the `batch_group_size`, the higher the length variety in a batch.
if self.batch_group_size > 0:
for i in range(len(sorted_samples) // self.batch_group_size):
offset = i * self.batch_group_size
end_offset = offset + self.batch_group_size
temp_items = sorted_samples[offset:end_offset]
random.shuffle(temp_items)
sorted_samples[offset:end_offset] = temp_items
if len(sorted_samples) == 0:
raise RuntimeError(" [!] No items left after filtering.")
# update items to the new sorted items
self.samples = sorted_samples
# logging
if self.verbose:
print(" | > Max length sequence: {}".format(np.max(lengths)))
print(" | > Min length sequence: {}".format(np.min(lengths)))
print(" | > Avg length sequence: {}".format(np.mean(lengths)))
print(
" | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}".format(
self.max_seq_len, self.min_seq_len, len(ignored)
)
)
print(" | > Batch group size: {}.".format(self.batch_group_size))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
return self.load_data(idx)
@staticmethod
def _sort_batch(batch, text_lengths):
"""Sort the batch by the input text lengths.
Args:
batch (Dict): Batch returned by `__getitem__`.
text_lengths (List[int]): Lengths of the input character sequences.
"""
text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True)
batch = [batch[idx] for idx in ids_sorted_decreasing]
return batch, text_lengths, ids_sorted_decreasing
def collate_fn(self, batch):
r"""
Perform preprocessing and create a final data batch:
        1. Sort batch instances by text length.
        2. Convert audio signals to features.
        3. Pad sequences to the longest instance in the batch.
        4. Convert everything to PyTorch tensors.
"""
# Puts each data field into a tensor with outer dimension batch size
if isinstance(batch[0], collections.abc.Mapping):
text_lengths = np.array([len(d["text"]) for d in batch])
# sort items with text input length for RNN efficiency
batch, text_lengths, _ = self._sort_batch(batch, text_lengths)
# convert list of dicts to dict of lists
batch = {k: [dic[k] for dic in batch] for k in batch[0]}
# compute features
features = [self.feature_extractor(w).astype("float32") for w in batch["waveform"]]
feature_lengths = [m.shape[1] for m in features]
# PAD sequences with longest instance in the batch
tokens = prepare_data(batch["tokens"], pad_val=self.tokenizer.pad_token_id)
# PAD features with longest instance
features = prepare_tensor(features, 1)
# B x D x T --> B x T x D
features = features.transpose(0, 2, 1)
# convert things to pytorch
text_lengths = torch.LongTensor(text_lengths)
tokens = torch.LongTensor(tokens)
token_lengths = torch.LongTensor(batch["token_length"])
features = torch.FloatTensor(features).contiguous()
feature_lengths = torch.LongTensor(feature_lengths)
            # format waveforms w.r.t. the computed feature frames
            # TODO: check whether this padding scheme is appropriate for MFCC features
wav_lengths = [w.shape[0] for w in batch["waveform"]]
max_wav_len = max(feature_lengths) * self.ap.hop_length
wav_lengths = torch.LongTensor(wav_lengths)
wav_padded = torch.zeros(len(batch["waveform"]), 1, max_wav_len)
for i, w in enumerate(batch["waveform"]):
mel_length = feature_lengths[i]
w = np.pad(w, (0, self.ap.hop_length), mode="edge")
w = w[: mel_length * self.ap.hop_length]
wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
wav_padded.transpose_(1, 2)
return {
"tokens": tokens,
"token_lengths": token_lengths,
"text": batch["text"],
"text_lengths": text_lengths,
"speaker_names": batch["speaker_name"],
"features": features, # (B x T x C)
"feature_lengths": feature_lengths, # (B)
"audio_file": batch["audio_file"],
"waveform": wav_padded, # (B x T x 1)
"raw_text": batch["text"],
}
        raise TypeError(
            "batch must contain tensors, numbers, dicts or lists; found {}".format(type(batch[0]))
        )
if __name__ == "__main__":
from torch.utils.data.dataloader import DataLoader
from TTS.config import BaseAudioConfig
from TTS.stt.datasets.formatters import librispeech
from TTS.stt.datasets.tokenizer import Tokenizer, extract_vocab
samples = librispeech("/home/ubuntu/librispeech/LibriSpeech/train-clean-100/")
texts = [s["text"] for s in samples]
vocab, vocab_dict = extract_vocab(texts)
print(vocab)
tokenizer = Tokenizer(vocab_dict)
ap = AudioProcessor(**BaseAudioConfig(sample_rate=16000))
dataset = STTDataset(samples, ap=ap, tokenizer=tokenizer)
for batch in dataset:
print(batch["text"])
break
loader = DataLoader(dataset, batch_size=2, collate_fn=dataset.collate_fn)
for batch in loader:
print(batch)
breakpoint()

58
TTS/stt/datasets/formatters.py Normal file
View File

@@ -0,0 +1,58 @@
import os
import re
import xml.etree.ElementTree as ET
from glob import glob
from pathlib import Path
from typing import List
from tqdm import tqdm
########################
# DATASETS
########################
def ljspeech(root_path, meta_file):
"""Normalizes the LJSpeech meta data file to TTS format
https://keithito.com/LJ-Speech-Dataset/"""
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "ljspeech"
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[2]
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
def librispeech(root_path, meta_files=None):
"""http://www.openslr.org/resources/12/"""
_delimiter = " "
_audio_ext = ".flac"
items = []
if meta_files is None:
meta_files = glob(f"{root_path}/**/*trans.txt", recursive=True)
else:
if isinstance(meta_files, str):
meta_files = [os.path.join(root_path, meta_files)]
for meta_file in meta_files:
_meta_file = os.path.basename(meta_file).split(".")[0]
with open(meta_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split(_delimiter, 1)
file_name = cols[0]
speaker_name, chapter_id, *_ = cols[0].split("-")
_root_path = os.path.join(root_path, f"{speaker_name}/{chapter_id}")
wav_file = os.path.join(_root_path, file_name + _audio_ext)
text = cols[1]
items.append({"text": text, "audio_file": wav_file, "speaker_name": "LibriSpeech_" + speaker_name})
for item in items:
        assert os.path.exists(item["audio_file"]), f" [!] Audio file does not exist - {item['audio_file']}"
return items
if __name__ == "__main__":
    items_ = librispeech(root_path="/home/ubuntu/librispeech/LibriSpeech/train-clean-100/", meta_files=None)