# Mirror of https://github.com/coqui-ai/TTS.git (368 lines, 15 KiB, Python)
import collections
import collections.abc
import os
import random
from multiprocessing import Manager, Pool

import numpy as np
import torch
import tqdm
from torch.utils.data import Dataset

from TTS.tts.utils.data import (prepare_data, prepare_stop_target,
                                prepare_tensor)
from TTS.tts.utils.text import (pad_with_eos_bos, phoneme_to_sequence,
                                text_to_sequence)


class MyDataset(Dataset):
    """Dataset pairing input text with audio for TTS model training.

    Each entry of ``meta_data`` is a mutable sequence of either
    ``[text, wav_file, speaker_name]`` or
    ``[text, wav_file, speaker_name, attn_file]``. ``__getitem__`` returns one
    sample dict per entry and ``collate_fn`` turns a list of such dicts into
    padded batch tensors.
    """

    def __init__(self,
                 outputs_per_step,
                 text_cleaner,
                 compute_linear_spec,
                 ap,
                 meta_data,
                 tp=None,
                 add_blank=False,
                 batch_group_size=0,
                 min_seq_len=0,
                 max_seq_len=float("inf"),
                 use_phonemes=True,
                 phoneme_cache_path=None,
                 phoneme_language="en-us",
                 enable_eos_bos=False,
                 speaker_mapping=None,
                 use_noise_augment=False,
                 verbose=False):
        """
        Args:
            outputs_per_step (int): number of time frames predicted per step.
            text_cleaner (str): text cleaner used for the dataset.
            compute_linear_spec (bool): compute linear spectrogram if True.
            ap (TTS.tts.utils.AudioProcessor): audio processor object. Must
                provide ``sample_rate``, ``load_wav``, ``melspectrogram`` and
                (when ``compute_linear_spec``) ``spectrogram``.
            meta_data (list): list of dataset instances (see class docstring).
                Stored by reference; ``compute_input_seq`` mutates it in place.
            tp (dict): forwarded as ``tp`` to the text/phoneme conversion
                helpers (presumably a custom character set - confirm).
            add_blank (bool): forwarded to the text/phoneme conversion helpers;
                phonemes computed with it are cached under a separate name.
            batch_group_size (int): (0) range of batch randomization after
                sorting sequences by length.
            min_seq_len (int): (0) minimum sequence length to be processed
                by the loader.
            max_seq_len (int): (float("inf")) maximum sequence length.
            use_phonemes (bool): (true) if true, text converted to phonemes.
            phoneme_cache_path (str): path to cache phoneme features. Required
                when ``use_phonemes`` is True; created if missing.
            phoneme_language (str): one the languages from
                https://github.com/bootphon/phonemizer#languages
            enable_eos_bos (bool): enable end of sentence and beginning of
                sentences characters.
            speaker_mapping (dict): maps wav file base names to dicts with an
                ``'embedding'`` entry; when set, ``collate_fn`` also returns
                speaker embeddings.
            use_noise_augment (bool): enable adding random noise to wav for
                augmentation.
            verbose (bool): print diagnostic information.

        Raises:
            ValueError: if ``use_phonemes`` is True and ``phoneme_cache_path``
                is None.
        """
        self.batch_group_size = batch_group_size
        self.items = meta_data
        self.outputs_per_step = outputs_per_step
        self.sample_rate = ap.sample_rate
        self.cleaners = text_cleaner
        self.compute_linear_spec = compute_linear_spec
        self.min_seq_len = min_seq_len
        self.max_seq_len = max_seq_len
        self.ap = ap
        self.tp = tp
        self.add_blank = add_blank
        self.use_phonemes = use_phonemes
        self.phoneme_cache_path = phoneme_cache_path
        self.phoneme_language = phoneme_language
        self.enable_eos_bos = enable_eos_bos
        self.speaker_mapping = speaker_mapping
        self.use_noise_augment = use_noise_augment
        self.verbose = verbose
        # set by compute_input_seq once item texts hold id sequences
        self.input_seq_computed = False
        if use_phonemes:
            if phoneme_cache_path is None:
                # os.path.isdir(None) would otherwise fail with an opaque TypeError
                raise ValueError(
                    " [!] `phoneme_cache_path` is required when "
                    "`use_phonemes` is True.")
            os.makedirs(phoneme_cache_path, exist_ok=True)
        if self.verbose:
            print("\n > DataLoader initialization")
            print(" | > Use phonemes: {}".format(self.use_phonemes))
            if use_phonemes:
                print(" | > phoneme language: {}".format(phoneme_language))
            print(" | > Number of instances : {}".format(len(self.items)))

    def load_wav(self, filename):
        """Load an audio file through the audio processor ``self.ap``."""
        audio = self.ap.load_wav(filename)
        return audio

    @staticmethod
    def load_np(filename):
        """Load a ``.npy`` array and cast it to float32."""
        data = np.load(filename).astype('float32')
        return data

    @staticmethod
    def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners,
                                             language, tp, add_blank):
        """Generate a phoneme id sequence from text and save it to ``cache_path``.

        Since the result is cached for later reuse, BOS/EOS chars are never
        added here; they are added dynamically later, based on the config
        option.

        Returns:
            np.ndarray: int32 phoneme id sequence.
        """
        phonemes = phoneme_to_sequence(text, [cleaners],
                                       language=language,
                                       enable_eos_bos=False,
                                       tp=tp,
                                       add_blank=add_blank)
        phonemes = np.asarray(phonemes, dtype=np.int32)
        np.save(cache_path, phonemes)
        return phonemes

    @staticmethod
    def _load_or_generate_phoneme_sequence(wav_file, text, phoneme_cache_path,
                                           enable_eos_bos, cleaners, language,
                                           tp, add_blank):
        """Return the phoneme id sequence for ``wav_file``, using a disk cache.

        The cache file is ``<phoneme_cache_path>/<wav base name>_phoneme.npy``
        (``_blanked_phoneme.npy`` when ``add_blank`` is set). A missing or
        unreadable cache file is (re)generated from ``text``.

        Returns:
            np.ndarray: int32 phoneme ids, padded with BOS/EOS ids when
            ``enable_eos_bos`` is set.
        """
        file_name = os.path.splitext(os.path.basename(wav_file))[0]

        # different names for normal phonemes and with blank chars.
        file_name_ext = '_blanked_phoneme.npy' if add_blank else '_phoneme.npy'
        cache_path = os.path.join(phoneme_cache_path,
                                  file_name + file_name_ext)
        try:
            phonemes = np.load(cache_path)
        except FileNotFoundError:
            phonemes = MyDataset._generate_and_cache_phoneme_sequence(
                text, cache_path, cleaners, language, tp, add_blank)
        except (ValueError, IOError):
            print(" [!] failed loading phonemes for {}. "
                  "Recomputing.".format(wav_file))
            phonemes = MyDataset._generate_and_cache_phoneme_sequence(
                text, cache_path, cleaners, language, tp, add_blank)
        if enable_eos_bos:
            phonemes = pad_with_eos_bos(phonemes, tp=tp)
        # always hand back int32, whichever path produced the sequence
        phonemes = np.asarray(phonemes, dtype=np.int32)
        return phonemes

    def load_data(self, idx):
        """Build the sample dict for item ``idx``.

        Returns:
            dict: keys ``'text'`` (int32 id array), ``'wav'`` (float32 array),
            ``'attn'`` (loaded attention array or None), ``'item_idx'`` (the
            item's wav path), ``'speaker_name'`` and ``'wav_file_name'``.
            If the id sequence is longer than ``max_seq_len``, the sample of
            the next item (wrapping around) is returned instead.
        """
        item = self.items[idx]

        if len(item) == 4:
            text, wav_file, speaker_name, attn_file = item
        else:
            text, wav_file, speaker_name = item
            attn_file = None

        wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)

        # apply noise for augmentation
        if self.use_noise_augment:
            # cast the noise so the wav stays float32 instead of upcasting
            noise = np.random.rand(*wav.shape).astype(np.float32)
            wav = wav + (1.0 / 32768.0) * noise

        if not self.input_seq_computed:
            if self.use_phonemes:
                text = self._load_or_generate_phoneme_sequence(
                    wav_file, text, self.phoneme_cache_path,
                    self.enable_eos_bos, self.cleaners, self.phoneme_language,
                    self.tp, self.add_blank)
            else:
                text = np.asarray(text_to_sequence(text, [self.cleaners],
                                                   tp=self.tp,
                                                   add_blank=self.add_blank),
                                  dtype=np.int32)

        assert text.size > 0, self.items[idx][1]
        assert wav.size > 0, self.items[idx][1]

        attn = np.load(attn_file) if attn_file is not None else None

        if len(text) > self.max_seq_len:
            # return a different sample if the converted (e.g. phonemized)
            # text is longer than the threshold; sort_items may have filtered
            # on the raw text length. Use the next index (wrapping around) so
            # the fallback is always a valid index.
            # TODO: find a better fix
            return self.load_data((idx + 1) % len(self.items))

        sample = {
            'text': text,
            'wav': wav,
            'attn': attn,
            'item_idx': self.items[idx][1],
            'speaker_name': speaker_name,
            'wav_file_name': os.path.basename(wav_file)
        }
        return sample

    @staticmethod
    def _phoneme_worker(args):
        """Pool worker: ``args`` is ``[item, func_args]``; returns the item's phonemes."""
        item = args[0]
        func_args = args[1]
        text, wav_file, *_ = item
        phonemes = MyDataset._load_or_generate_phoneme_sequence(
            wav_file, text, *func_args)
        return phonemes

    def compute_input_seq(self, num_workers=0):
        """Precompute the input id sequence of every item, in place.

        Call it before passing the dataset to a data loader. Each item's text
        (``item[0]``) is replaced by its int32 id sequence (phonemes when
        ``use_phonemes`` is set, characters otherwise), and ``load_data``
        afterwards uses these sequences directly.

        Args:
            num_workers (int): worker processes used for phoneme computation;
                0 computes everything in the current process.
        """
        if not self.use_phonemes:
            if self.verbose:
                print(" | > Computing input sequences ...")
            for idx, item in enumerate(tqdm.tqdm(self.items)):
                text, *_ = item
                sequence = np.asarray(text_to_sequence(
                    text, [self.cleaners],
                    tp=self.tp,
                    add_blank=self.add_blank),
                                      dtype=np.int32)
                self.items[idx][0] = sequence

        else:
            func_args = [
                self.phoneme_cache_path, self.enable_eos_bos, self.cleaners,
                self.phoneme_language, self.tp, self.add_blank
            ]
            if self.verbose:
                print(" | > Computing phonemes ...")
            if num_workers == 0:
                for idx, item in enumerate(tqdm.tqdm(self.items)):
                    phonemes = self._phoneme_worker([item, func_args])
                    self.items[idx][0] = phonemes
            else:
                with Pool(num_workers) as pool:
                    phonemes = list(
                        tqdm.tqdm(pool.imap(MyDataset._phoneme_worker,
                                            [[item, func_args]
                                             for item in self.items]),
                                  total=len(self.items)))
                for idx, sequence in enumerate(phonemes):
                    self.items[idx][0] = sequence
        # item texts now hold id sequences; load_data must not convert again
        self.input_seq_computed = True

    def sort_items(self):
        r"""Sort instances based on text length in ascending order.

        Items whose ``len(item[0])`` is outside ``[min_seq_len, max_seq_len]``
        are dropped. When ``batch_group_size > 0``, items are shuffled within
        consecutive groups of that size.
        """
        lengths = np.array([len(ins[0]) for ins in self.items])

        idxs = np.argsort(lengths)
        new_items = []
        ignored = []
        for idx in idxs:
            length = lengths[idx]
            if length < self.min_seq_len or length > self.max_seq_len:
                ignored.append(idx)
            else:
                new_items.append(self.items[idx])
        # shuffle batch groups
        if self.batch_group_size > 0:
            for i in range(len(new_items) // self.batch_group_size):
                offset = i * self.batch_group_size
                end_offset = offset + self.batch_group_size
                temp_items = new_items[offset:end_offset]
                random.shuffle(temp_items)
                new_items[offset:end_offset] = temp_items
        self.items = new_items

        if self.verbose:
            # np.max/np.min raise on an empty array
            if lengths.size > 0:
                print(" | > Max length sequence: {}".format(np.max(lengths)))
                print(" | > Min length sequence: {}".format(np.min(lengths)))
                print(" | > Avg length sequence: {}".format(np.mean(lengths)))
            print(
                " | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}"
                .format(self.max_seq_len, self.min_seq_len, len(ignored)))
            print(" | > Batch group size: {}.".format(self.batch_group_size))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.load_data(idx)

    def collate_fn(self, batch):
        r"""
        Perform preprocessing and create a final data batch:
        1. Sort batch instances by text-length
        2. Convert Audio signal to Spectrograms.
        3. PAD sequences wrt r.
        4. Load to Torch.

        Args:
            batch (list[dict]): samples as returned by ``load_data``.

        Returns:
            tuple: ``(text, text_lengths, speaker_name, linear, mel,
            mel_lengths, stop_targets, item_idxs, speaker_embedding, attns)``,
            all ordered by decreasing text length. ``mel`` and ``linear`` are
            laid out B x T x D. ``linear``, ``speaker_embedding`` and
            ``attns`` are None when the corresponding feature is disabled.

        Raises:
            TypeError: if the batch elements are not mappings.
        """
        # Puts each data field into a tensor with outer dimension batch size
        if not isinstance(batch[0], collections.abc.Mapping):
            raise TypeError(
                "batch must contain tensors, numbers, dicts or lists; "
                "found {}".format(type(batch[0])))

        text_lengths = np.array([len(d["text"]) for d in batch])

        # sort items with text input length for RNN efficiency
        text_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor(text_lengths), dim=0, descending=True)

        wav = [batch[idx]['wav'] for idx in ids_sorted_decreasing]
        item_idxs = [
            batch[idx]['item_idx'] for idx in ids_sorted_decreasing
        ]
        text = [batch[idx]['text'] for idx in ids_sorted_decreasing]

        speaker_name = [
            batch[idx]['speaker_name'] for idx in ids_sorted_decreasing
        ]
        # get speaker embeddings
        if self.speaker_mapping is not None:
            wav_files_names = [
                batch[idx]['wav_file_name']
                for idx in ids_sorted_decreasing
            ]
            speaker_embedding = [
                self.speaker_mapping[w]['embedding']
                for w in wav_files_names
            ]
        else:
            speaker_embedding = None
        # compute features
        mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]

        # per-item mel is D x T, so the frame count is axis 1
        mel_lengths = [m.shape[1] for m in mel]

        # compute 'stop token' targets
        stop_targets = [
            np.array([0.] * (mel_len - 1) + [1.])
            for mel_len in mel_lengths
        ]

        # PAD stop targets
        stop_targets = prepare_stop_target(stop_targets,
                                           self.outputs_per_step)

        # PAD sequences with longest instance in the batch
        text = prepare_data(text).astype(np.int32)

        # PAD features with longest instance
        mel = prepare_tensor(mel, self.outputs_per_step)

        # B x D x T --> B x T x D
        mel = mel.transpose(0, 2, 1)

        # convert things to pytorch
        # (text_lengths is already a LongTensor returned by torch.sort)
        text = torch.LongTensor(text)
        mel = torch.FloatTensor(mel).contiguous()
        mel_lengths = torch.LongTensor(mel_lengths)
        stop_targets = torch.FloatTensor(stop_targets)

        if speaker_embedding is not None:
            speaker_embedding = torch.FloatTensor(speaker_embedding)

        # compute linear spectrogram
        if self.compute_linear_spec:
            linear = [
                self.ap.spectrogram(w).astype('float32') for w in wav
            ]
            linear = prepare_tensor(linear, self.outputs_per_step)
            linear = linear.transpose(0, 2, 1)
            assert mel.shape[1] == linear.shape[1]
            linear = torch.FloatTensor(linear).contiguous()
        else:
            linear = None

        # collate attention alignments
        if batch[0]['attn'] is not None:
            attns = [batch[idx]['attn'].T for idx in ids_sorted_decreasing]
            for idx, attn in enumerate(attns):
                # pad each alignment to (padded text len, padded mel len)
                pad2 = mel.shape[1] - attn.shape[1]
                pad1 = text.shape[1] - attn.shape[0]
                attn = np.pad(attn, [[0, pad1], [0, pad2]])
                attns[idx] = attn
            attns = prepare_tensor(attns, self.outputs_per_step)
            attns = torch.FloatTensor(attns).unsqueeze(1)
        else:
            attns = None
        return text, text_lengths, speaker_name, linear, mel, mel_lengths, \
            stop_targets, item_idxs, speaker_embedding, attns