coqui-tts/TTS/tts/datasets/dataset.py

645 lines
25 KiB
Python

import collections
import os
import random
from multiprocessing import Pool
from typing import Dict, List
import numpy as np
import torch
import tqdm
from torch.utils.data import Dataset
from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
from TTS.utils.audio import AudioProcessor
class TTSDataset(Dataset):
def __init__(
self,
outputs_per_step: int,
text_cleaner: list,
compute_linear_spec: bool,
ap: AudioProcessor,
meta_data: List[Dict],
compute_f0: bool = False,
f0_cache_path: str = None,
characters: Dict = None,
custom_symbols: List = None,
add_blank: bool = False,
return_wav: bool = False,
batch_group_size: int = 0,
min_seq_len: int = 0,
max_seq_len: int = float("inf"),
use_phonemes: bool = False,
phoneme_cache_path: str = None,
phoneme_language: str = "en-us",
enable_eos_bos: bool = False,
speaker_id_mapping: Dict = None,
d_vector_mapping: Dict = None,
language_id_mapping: Dict = None,
use_noise_augment: bool = False,
verbose: bool = False,
):
"""Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
If you need something different, you can inherit and override.
Args:
outputs_per_step (int): Number of time frames predicted per step.
text_cleaner (list): List of text cleaners to clean the input text before converting to sequence IDs.
compute_linear_spec (bool): compute linear spectrogram if True.
ap (TTS.tts.utils.AudioProcessor): Audio processor object.
meta_data (list): List of dataset samples.
compute_f0 (bool): compute f0 if True. Defaults to False.
f0_cache_path (str): Path to store f0 cache. Defaults to None.
characters (dict): `dict` of custom text characters used for converting texts to sequences.
custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own
set of symbols need to pass it here. Defaults to `None`.
add_blank (bool): Add a special `blank` character after every other character. It helps some
models achieve better results. Defaults to false.
return_wav (bool): Return the waveform of the sample. Defaults to False.
batch_group_size (int): Range of batch randomization after sorting
sequences by length. It shuffles each batch with bucketing to gather similar lenght sequences in a
batch. Set 0 to disable. Defaults to 0.
min_seq_len (int): Minimum input sequence length to be processed
by sort_inputs`. Filter out input sequences that are shorter than this. Some models have a
minimum input length due to its architecture. Defaults to 0.
max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this.
It helps for controlling the VRAM usage against long input sequences. Especially models with
RNN layers are sensitive to input length. Defaults to `Inf`.
use_phonemes (bool): If true, input text converted to phonemes. Defaults to false.
phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a
separate file. Defaults to None.
phoneme_language (str): One the languages from supported by the phonemizer interface. Defaults to `en-us`.
enable_eos_bos (bool): Enable the `end of sentence` and the `beginning of sentences characters`. Defaults
to False.
speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the
embedding layer. Defaults to None.
d_vector_mapping (dict): Mapping of wav files to computed d-vectors. Defaults to None.
use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.
verbose (bool): Print diagnostic information. Defaults to false.
"""
super().__init__()
self.batch_group_size = batch_group_size
self.items = meta_data
self.outputs_per_step = outputs_per_step
self.sample_rate = ap.sample_rate
self.cleaners = text_cleaner
self.compute_linear_spec = compute_linear_spec
self.return_wav = return_wav
self.compute_f0 = compute_f0
self.f0_cache_path = f0_cache_path
self.min_seq_len = min_seq_len
self.max_seq_len = max_seq_len
self.ap = ap
self.characters = characters
self.custom_symbols = custom_symbols
self.add_blank = add_blank
self.use_phonemes = use_phonemes
self.phoneme_cache_path = phoneme_cache_path
self.phoneme_language = phoneme_language
self.enable_eos_bos = enable_eos_bos
self.speaker_id_mapping = speaker_id_mapping
self.d_vector_mapping = d_vector_mapping
self.language_id_mapping = language_id_mapping
self.use_noise_augment = use_noise_augment
self.verbose = verbose
self.input_seq_computed = False
self.rescue_item_idx = 1
self.pitch_computed = False
if use_phonemes and not os.path.isdir(phoneme_cache_path):
os.makedirs(phoneme_cache_path, exist_ok=True)
if compute_f0:
self.pitch_extractor = PitchExtractor(self.items, verbose=verbose)
if self.verbose:
print("\n > DataLoader initialization")
print(" | > Use phonemes: {}".format(self.use_phonemes))
if use_phonemes:
print(" | > phoneme language: {}".format(phoneme_language))
print(" | > Number of instances : {}".format(len(self.items)))
def load_wav(self, filename):
audio = self.ap.load_wav(filename)
return audio
@staticmethod
def load_np(filename):
data = np.load(filename).astype("float32")
return data
@staticmethod
def _generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
):
"""generate a phoneme sequence from text.
since the usage is for subsequent caching, we never add bos and
eos chars here. Instead we add those dynamically later; based on the
config option."""
phonemes = phoneme_to_sequence(
text,
[cleaners],
language=language,
enable_eos_bos=False,
custom_symbols=custom_symbols,
tp=characters,
add_blank=add_blank,
)
phonemes = np.asarray(phonemes, dtype=np.int32)
np.save(cache_path, phonemes)
return phonemes
@staticmethod
def _load_or_generate_phoneme_sequence(
wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, custom_symbols, characters, add_blank
):
file_name = os.path.splitext(os.path.basename(wav_file))[0]
# different names for normal phonemes and with blank chars.
file_name_ext = "_blanked_phoneme.npy" if add_blank else "_phoneme.npy"
cache_path = os.path.join(phoneme_cache_path, file_name + file_name_ext)
try:
phonemes = np.load(cache_path)
except FileNotFoundError:
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
)
except (ValueError, IOError):
print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
)
if enable_eos_bos:
phonemes = pad_with_eos_bos(phonemes, tp=characters)
phonemes = np.asarray(phonemes, dtype=np.int32)
return phonemes
def load_data(self, idx):
item = self.items[idx]
raw_text = item["text"]
wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32)
# apply noise for augmentation
if self.use_noise_augment:
wav = wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)
if not self.input_seq_computed:
if self.use_phonemes:
text = self._load_or_generate_phoneme_sequence(
item["audio_file"],
item["text"],
self.phoneme_cache_path,
self.enable_eos_bos,
self.cleaners,
item["language"] if item["language"] else self.phoneme_language,
self.custom_symbols,
self.characters,
self.add_blank,
)
else:
text = np.asarray(
text_to_sequence(
item["text"],
[self.cleaners],
custom_symbols=self.custom_symbols,
tp=self.characters,
add_blank=self.add_blank,
),
dtype=np.int32,
)
assert text.size > 0, self.items[idx]["audio_file"]
assert wav.size > 0, self.items[idx]["audio_file"]
attn = None
if "alignment_file" in item:
attn = np.load(item["alignment_file"])
if len(text) > self.max_seq_len:
# return a different sample if the phonemized
# text is longer than the threshold
# TODO: find a better fix
return self.load_data(self.rescue_item_idx)
pitch = None
if self.compute_f0:
pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, item["audio_file"], self.f0_cache_path)
pitch = self.pitch_extractor.normalize_pitch(pitch.astype(np.float32))
sample = {
"raw_text": raw_text,
"text": text,
"wav": wav,
"pitch": pitch,
"attn": attn,
"item_idx": item["audio_file"],
"speaker_name": item["speaker_name"],
"language_name": item["language"],
"wav_file_name": os.path.basename(item["audio_file"]),
}
return sample
@staticmethod
def _phoneme_worker(args):
item = args[0]
func_args = args[1]
func_args[3] = (
item["language"] if "language" in item and item["language"] else func_args[3]
) # override phoneme language if specified by the dataset formatter
phonemes = TTSDataset._load_or_generate_phoneme_sequence(item["audio_file"], item["text"], *func_args)
return phonemes
def compute_input_seq(self, num_workers=0):
"""Compute the input sequences with multi-processing.
Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
if not self.use_phonemes:
if self.verbose:
print(" | > Computing input sequences ...")
for idx, item in enumerate(tqdm.tqdm(self.items)):
sequence = np.asarray(
text_to_sequence(
item["text"],
[self.cleaners],
custom_symbols=self.custom_symbols,
tp=self.characters,
add_blank=self.add_blank,
),
dtype=np.int32,
)
self.items[idx][0] = sequence
else:
func_args = [
self.phoneme_cache_path,
self.enable_eos_bos,
self.cleaners,
self.phoneme_language,
self.custom_symbols,
self.characters,
self.add_blank,
]
if self.verbose:
print(" | > Computing phonemes ...")
if num_workers == 0:
for idx, item in enumerate(tqdm.tqdm(self.items)):
phonemes = self._phoneme_worker([item, func_args])
self.items[idx][0] = phonemes
else:
with Pool(num_workers) as p:
phonemes = list(
tqdm.tqdm(
p.imap(TTSDataset._phoneme_worker, [[item, func_args] for item in self.items]),
total=len(self.items),
)
)
for idx, p in enumerate(phonemes):
self.items[idx][0] = p
def sort_and_filter_items(self, by_audio_len=False):
r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
range.
Args:
by_audio_len (bool): if True, sort by audio length else by text length.
"""
# compute the target sequence length
if by_audio_len:
lengths = []
for item in self.items:
lengths.append(os.path.getsize(item["audio_file"]) / 16 * 8) # assuming 16bit audio
lengths = np.array(lengths)
else:
lengths = np.array([len(ins["text"]) for ins in self.items])
idxs = np.argsort(lengths)
new_items = []
ignored = []
for i, idx in enumerate(idxs):
length = lengths[idx]
if length < self.min_seq_len or length > self.max_seq_len:
ignored.append(idx)
else:
new_items.append(self.items[idx])
# shuffle batch groups
if self.batch_group_size > 0:
for i in range(len(new_items) // self.batch_group_size):
offset = i * self.batch_group_size
end_offset = offset + self.batch_group_size
temp_items = new_items[offset:end_offset]
random.shuffle(temp_items)
new_items[offset:end_offset] = temp_items
self.items = new_items
if self.verbose:
print(" | > Max length sequence: {}".format(np.max(lengths)))
print(" | > Min length sequence: {}".format(np.min(lengths)))
print(" | > Avg length sequence: {}".format(np.mean(lengths)))
print(
" | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}".format(
self.max_seq_len, self.min_seq_len, len(ignored)
)
)
print(" | > Batch group size: {}.".format(self.batch_group_size))
def __len__(self):
return len(self.items)
def __getitem__(self, idx):
return self.load_data(idx)
@staticmethod
def _sort_batch(batch, text_lengths):
"""Sort the batch by the input text length for RNN efficiency.
Args:
batch (Dict): Batch returned by `__getitem__`.
text_lengths (List[int]): Lengths of the input character sequences.
"""
text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True)
batch = [batch[idx] for idx in ids_sorted_decreasing]
return batch, text_lengths, ids_sorted_decreasing
def collate_fn(self, batch):
r"""
Perform preprocessing and create a final data batch:
1. Sort batch instances by text-length
2. Convert Audio signal to features.
3. PAD sequences wrt r.
4. Load to Torch.
"""
# Puts each data field into a tensor with outer dimension batch size
if isinstance(batch[0], collections.abc.Mapping):
text_lengths = np.array([len(d["text"]) for d in batch])
# sort items with text input length for RNN efficiency
batch, text_lengths, ids_sorted_decreasing = self._sort_batch(batch, text_lengths)
# convert list of dicts to dict of lists
batch = {k: [dic[k] for dic in batch] for k in batch[0]}
# get language ids from language names
if self.language_id_mapping is not None:
language_ids = [self.language_id_mapping[ln] for ln in batch["language_name"]]
else:
language_ids = None
# get pre-computed d-vectors
if self.d_vector_mapping is not None:
wav_files_names = list(batch["wav_file_name"])
d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
else:
d_vectors = None
# get numerical speaker ids from speaker names
if self.speaker_id_mapping:
speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]]
else:
speaker_ids = None
# compute features
mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]]
mel_lengths = [m.shape[1] for m in mel]
# lengths adjusted by the reduction factor
mel_lengths_adjusted = [
m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step))
if m.shape[1] % self.outputs_per_step
else m.shape[1]
for m in mel
]
# compute 'stop token' targets
stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths]
# PAD stop targets
stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)
# PAD sequences with longest instance in the batch
text = prepare_data(batch["text"]).astype(np.int32)
# PAD features with longest instance
mel = prepare_tensor(mel, self.outputs_per_step)
# B x D x T --> B x T x D
mel = mel.transpose(0, 2, 1)
# convert things to pytorch
text_lengths = torch.LongTensor(text_lengths)
text = torch.LongTensor(text)
mel = torch.FloatTensor(mel).contiguous()
mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets)
if d_vectors is not None:
d_vectors = torch.FloatTensor(d_vectors)
if speaker_ids is not None:
speaker_ids = torch.LongTensor(speaker_ids)
if language_ids is not None:
language_ids = torch.LongTensor(language_ids)
# compute linear spectrogram
if self.compute_linear_spec:
linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
linear = prepare_tensor(linear, self.outputs_per_step)
linear = linear.transpose(0, 2, 1)
assert mel.shape[1] == linear.shape[1]
linear = torch.FloatTensor(linear).contiguous()
else:
linear = None
# format waveforms
wav_padded = None
if self.return_wav:
wav_lengths = [w.shape[0] for w in batch["wav"]]
max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
wav_lengths = torch.LongTensor(wav_lengths)
wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len)
for i, w in enumerate(batch["wav"]):
mel_length = mel_lengths_adjusted[i]
w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
w = w[: mel_length * self.ap.hop_length]
wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
wav_padded.transpose_(1, 2)
# compute f0
# TODO: compare perf in collate_fn vs in load_data
if self.compute_f0:
pitch = prepare_data(batch["pitch"])
assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
pitch = torch.FloatTensor(pitch)[:, None, :].contiguous() # B x 1 xT
else:
pitch = None
# collate attention alignments
if batch["attn"][0] is not None:
attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing]
for idx, attn in enumerate(attns):
pad2 = mel.shape[1] - attn.shape[1]
pad1 = text.shape[1] - attn.shape[0]
assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}"
attn = np.pad(attn, [[0, pad1], [0, pad2]])
attns[idx] = attn
attns = prepare_tensor(attns, self.outputs_per_step)
attns = torch.FloatTensor(attns).unsqueeze(1)
else:
attns = None
# TODO: return dictionary
return {
"text": text,
"text_lengths": text_lengths,
"speaker_names": batch["speaker_name"],
"linear": linear,
"mel": mel,
"mel_lengths": mel_lengths,
"stop_targets": stop_targets,
"item_idxs": batch["item_idx"],
"d_vectors": d_vectors,
"speaker_ids": speaker_ids,
"attns": attns,
"waveform": wav_padded,
"raw_text": batch["raw_text"],
"pitch": pitch,
"language_ids": language_ids,
}
raise TypeError(
(
"batch must contain tensors, numbers, dicts or lists;\
found {}".format(
type(batch[0])
)
)
)
class PitchExtractor:
    """Pitch Extractor for computing F0 from wav files.

    Args:
        items (List[Dict]): Dataset samples. Each sample must have an `"audio_file"` key.
        verbose (bool): Whether to print the progress.
    """

    def __init__(
        self,
        items: List[Dict],
        verbose=False,
    ):
        self.items = items
        self.verbose = verbose
        # normalization statistics; populated by `load_pitch_stats()`
        self.mean = None
        self.std = None

    @staticmethod
    def create_pitch_file_path(wav_file, cache_path):
        """Return the cache path of the pitch file for `wav_file`: `<cache_path>/<wav stem>_pitch.npy`."""
        file_name = os.path.splitext(os.path.basename(wav_file))[0]
        return os.path.join(cache_path, file_name + "_pitch.npy")

    @staticmethod
    def _compute_and_save_pitch(ap, wav_file, pitch_file=None):
        """Load `wav_file`, compute its F0 with `ap.compute_f0`, save it to `pitch_file` if given, and return it."""
        wav = ap.load_wav(wav_file)
        pitch = ap.compute_f0(wav)
        if pitch_file:
            np.save(pitch_file, pitch)
        return pitch

    @staticmethod
    def compute_pitch_stats(pitch_vecs):
        """Return the mean and std of all non-zero values across `pitch_vecs` (zeros are ignored)."""
        nonzeros = np.concatenate([np.asarray(v)[np.asarray(v) != 0.0] for v in pitch_vecs])
        mean, std = np.mean(nonzeros), np.std(nonzeros)
        return mean, std

    def normalize_pitch(self, pitch):
        """Return `(pitch - mean) / std` as a new array, keeping zero entries at zero."""
        # boolean mask keeps the masking element-wise for inputs of any rank
        zero_mask = pitch == 0.0
        pitch = (pitch - self.mean) / self.std
        pitch[zero_mask] = 0.0
        return pitch

    def denormalize_pitch(self, pitch):
        """Return `pitch * std + mean` as a new array, keeping zero entries at zero.

        The input is not modified.
        """
        zero_mask = pitch == 0.0
        pitch = pitch * self.std + self.mean
        pitch[zero_mask] = 0.0
        return pitch

    @staticmethod
    def load_or_compute_pitch(ap, wav_file, cache_path):
        """
        compute pitch and return a numpy array of pitch values
        (loaded from the cache file if it exists, otherwise computed and cached), as float32.
        """
        pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path)
        if not os.path.exists(pitch_file):
            pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
        else:
            pitch = np.load(pitch_file)
        return pitch.astype(np.float32)

    @staticmethod
    def _pitch_worker(args):
        """Return the pitch of one item, computing and caching it if no cache file exists yet.

        `args` is `[item, ap, cache_path]`.
        """
        item = args[0]
        ap = args[1]
        cache_path = args[2]
        pitch_file = PitchExtractor.create_pitch_file_path(item["audio_file"], cache_path)
        if not os.path.exists(pitch_file):
            return PitchExtractor._compute_and_save_pitch(ap, item["audio_file"], pitch_file)
        # already cached (e.g. by an interrupted earlier run): load it so it still counts towards the stats
        return np.load(pitch_file)

    def compute_pitch(self, ap, cache_path, num_workers=0):
        """Compute the pitch of every item, cache it under `cache_path`, and save the normalization stats.

        The stats are written to `<cache_path>/pitch_stats.npy` as a pickled `{"mean", "std"}` dict.

        Args:
            ap: Audio processor providing `load_wav` and `compute_f0`.
            cache_path (str): Directory for the per-file pitch cache and the stats file. Created if missing.
            num_workers (int): Number of worker processes. `0` runs in the current process.
        """
        if not os.path.exists(cache_path):
            os.makedirs(cache_path, exist_ok=True)

        if self.verbose:
            print(" | > Computing pitch features ...")
        if num_workers == 0:
            pitch_vecs = [self._pitch_worker([item, ap, cache_path]) for item in tqdm.tqdm(self.items)]
        else:
            with Pool(num_workers) as pool:
                pitch_vecs = list(
                    tqdm.tqdm(
                        pool.imap(PitchExtractor._pitch_worker, [[item, ap, cache_path] for item in self.items]),
                        total=len(self.items),
                    )
                )
        pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs)
        pitch_stats = {"mean": pitch_mean, "std": pitch_std}
        np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)

    def load_pitch_stats(self, cache_path):
        """Load the mean/std saved by `compute_pitch()` into `self.mean` / `self.std` as float32."""
        stats_path = os.path.join(cache_path, "pitch_stats.npy")
        # allow_pickle is required for the saved dict; only load stats files this class produced (trusted cache)
        stats = np.load(stats_path, allow_pickle=True).item()
        # np.float32(...) accepts both numpy scalars and plain Python floats
        self.mean = np.float32(stats["mean"])
        self.std = np.float32(stats["std"])