Refactor TTSDataset

Return a dict from `collate`
Refactor batch handling in `collate`
A couple of bug fixes
Eren Gölge 2021-09-06 14:24:06 +00:00
parent 6878ddfbea
commit 91a70e80b2
4 changed files with 74 additions and 51 deletions


@@ -77,14 +77,14 @@ def set_filename(wav_path, out_path):
 def format_data(data):
     # setup input data
-    text_input = data[0]
-    text_lengths = data[1]
-    mel_input = data[4]
-    mel_lengths = data[5]
-    item_idx = data[7]
-    d_vectors = data[8]
-    speaker_ids = data[9]
-    attn_mask = data[10]
+    text_input = data['text']
+    text_lengths = data['text_lengths']
+    mel_input = data['mel']
+    mel_lengths = data['mel_lengths']
+    item_idx = data['item_idxs']
+    d_vectors = data['d_vectors']
+    speaker_ids = data['speaker_ids']
+    attn_mask = data['attns']
     avg_text_length = torch.mean(text_lengths.float())
     avg_spec_length = torch.mean(mel_lengths.float())
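Side note on the change above: keyed access makes format_data robust to batch-schema changes, since adding or reordering fields no longer shifts every positional index downstream. A minimal sketch with made-up tensor shapes:

import torch

batch = {
    "text": torch.zeros(2, 10, dtype=torch.long),
    "text_lengths": torch.tensor([10, 7]),
    "mel": torch.zeros(2, 50, 80),
    "mel_lengths": torch.tensor([50, 31]),
}
text_input = batch["text"]  # stable no matter how many fields precede it
mel_input = batch["mel"]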
@@ -132,9 +132,8 @@ def inference(
             speaker_c = speaker_ids
         elif d_vectors is not None:
             speaker_c = d_vectors
         outputs = model.inference_with_MAS(
-            text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c}
+            text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids}
         )
         model_output = outputs["model_outputs"]
         model_output = model_output.transpose(1, 2).detach().cpu().numpy()


@@ -271,8 +271,13 @@ class Trainer:
         # setup scheduler
         self.scheduler = self.get_scheduler(self.model, self.config, self.optimizer)
-        if self.args.continue_path:
-            self.scheduler.last_epoch = self.restore_step
+        if self.scheduler is not None:
+            if self.args.continue_path:
+                if isinstance(self.scheduler, list):
+                    for scheduler in self.scheduler:
+                        scheduler.last_epoch = self.restore_step
+                else:
+                    self.scheduler.last_epoch = self.restore_step

         # DISTRUBUTED
         if self.num_gpus > 1:
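The new guard tolerates both a missing scheduler and a list of schedulers. The same restore pattern in isolation, assuming a single StepLR and a hypothetical restore_step read from a checkpoint:

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

restore_step = 42  # hypothetical step restored from a checkpoint
if scheduler is not None:
    schedulers = scheduler if isinstance(scheduler, list) else [scheduler]
    for s in schedulers:
        s.last_epoch = restore_step  # realign the LR schedule with training progress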


@@ -142,7 +142,7 @@ class BaseTTSConfig(BaseTrainingConfig):
             enable / disable masking loss values against padded segments of samples in a batch.
         sort_by_audio_len (bool):
-            If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `True`.
+            If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `False`.
         min_seq_len (int):
             Minimum sequence length to be used at training.
@@ -201,7 +201,7 @@ class BaseTTSConfig(BaseTrainingConfig):
     batch_group_size: int = 0
     loss_masking: bool = None
     # dataloading
-    sort_by_audio_len: bool = True
+    sort_by_audio_len: bool = False
     min_seq_len: int = 1
     max_seq_len: int = float("inf")
     compute_f0: bool = False
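Because the default flips to False, setups that relied on audio-length sorting now have to opt in explicitly. A hedged example, assuming the class keeps its coqpit-style keyword constructor and the import path shown:

from TTS.tts.configs.shared_configs import BaseTTSConfig  # import path assumed

config = BaseTTSConfig(
    sort_by_audio_len=True,  # restore the previous behaviour explicitly
    min_seq_len=1,
    max_seq_len=500,
)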


@@ -9,7 +9,7 @@ import torch
 import tqdm
 from torch.utils.data import Dataset

-from TTS.tts.utils.data import _pad_data, prepare_data, prepare_stop_target, prepare_tensor
+from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
 from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
 from TTS.utils.audio import AudioProcessor
@@ -249,8 +249,8 @@ class TTSDataset(Dataset):
         pitch = None
         if self.compute_f0:
-            pitch = self.pitch_extractor._load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
-            pitch = self.pitch_extractor.normalize_pitch(pitch)
+            pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
+            pitch = self.pitch_extractor.normalize_pitch(pitch.astype(np.float32))

         sample = {
             "raw_text": raw_text,
@@ -356,6 +356,11 @@ class TTSDataset(Dataset):
                 temp_items = new_items[offset:end_offset]
                 random.shuffle(temp_items)
                 new_items[offset:end_offset] = temp_items

+        if len(new_items) == 0:
+            raise RuntimeError(" [!] No items left after filtering.")
+
         # update items to the new sorted items
         self.items = new_items

         # logging
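For context, the loop above shuffles only within fixed-size buckets of the length-sorted list, so batches stay roughly length-homogeneous while sample order still varies between epochs. A toy run with twelve already-sorted items:

import random

new_items = list(range(12))  # stands in for items already sorted by length
batch_group_size = 4
for offset in range(0, len(new_items), batch_group_size):
    end_offset = offset + batch_group_size
    temp_items = new_items[offset:end_offset]
    random.shuffle(temp_items)
    new_items[offset:end_offset] = temp_items
# e.g. [2, 0, 3, 1, 6, 5, 4, 7, 9, 11, 8, 10]: shuffled only inside each bucket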
@@ -376,6 +381,18 @@ class TTSDataset(Dataset):
     def __getitem__(self, idx):
         return self.load_data(idx)

+    @staticmethod
+    def _sort_batch(batch, text_lengths):
+        """Sort the batch by the input text length for RNN efficiency.
+
+        Args:
+            batch (Dict): Batch returned by `__getitem__`.
+            text_lengths (List[int]): Lengths of the input character sequences.
+        """
+        text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True)
+        batch = [batch[idx] for idx in ids_sorted_decreasing]
+        return batch, text_lengths, ids_sorted_decreasing
+
     def collate_fn(self, batch):
         r"""
         Perform preprocessing and create a final data batch:
@@ -388,30 +405,27 @@ class TTSDataset(Dataset):
         # Puts each data field into a tensor with outer dimension batch size
         if isinstance(batch[0], collections.abc.Mapping):
-            text_lenghts = np.array([len(d["text"]) for d in batch])
+            text_lengths = np.array([len(d["text"]) for d in batch])

             # sort items with text input length for RNN efficiency
-            text_lenghts, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lenghts), dim=0, descending=True)
+            batch, text_lengths, ids_sorted_decreasing = self._sort_batch(batch, text_lengths)

-            wav = [batch[idx]["wav"] for idx in ids_sorted_decreasing]
-            item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing]
-            text = [batch[idx]["text"] for idx in ids_sorted_decreasing]
-            raw_text = [batch[idx]["raw_text"] for idx in ids_sorted_decreasing]
+            # convert list of dicts to dict of lists
+            batch = {k: [dic[k] for dic in batch] for k in batch[0]}

-            speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
             # get pre-computed d-vectors
             if self.d_vector_mapping is not None:
-                wav_files_names = [batch[idx]["wav_file_name"] for idx in ids_sorted_decreasing]
+                wav_files_names = [batch["wav_file_name"][idx] for idx in ids_sorted_decreasing]
                 d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
             else:
                 d_vectors = None

             # get numerical speaker ids from speaker names
             if self.speaker_id_mapping:
-                speaker_ids = [self.speaker_id_mapping[sn] for sn in speaker_names]
+                speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]]
             else:
                 speaker_ids = None

             # compute features
-            mel = [self.ap.melspectrogram(w).astype("float32") for w in wav]
+            mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]]

             mel_lengths = [m.shape[1] for m in mel]
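The two new steps in isolation: _sort_batch reorders samples longest-text-first, then the dict comprehension turns the list of per-sample dicts into a dict of per-key lists so later code can address fields by name. A toy run:

import torch

batch = [{"text": "hello"}, {"text": "hi"}, {"text": "greetings"}]
text_lengths = [len(d["text"]) for d in batch]

# sort longest-first, as _sort_batch does
text_lengths, ids_sorted_decreasing = torch.sort(
    torch.LongTensor(text_lengths), dim=0, descending=True
)
batch = [batch[idx] for idx in ids_sorted_decreasing]

# convert list of dicts to dict of lists
batch = {k: [dic[k] for dic in batch] for k in batch[0]}
# batch is now {"text": ["greetings", "hello", "hi"]}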
@@ -430,7 +444,7 @@ class TTSDataset(Dataset):
             stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)

             # PAD sequences with longest instance in the batch
-            text = prepare_data(text).astype(np.int32)
+            text = prepare_data(batch["text"]).astype(np.int32)

             # PAD features with longest instance
             mel = prepare_tensor(mel, self.outputs_per_step)
@@ -439,7 +453,7 @@ class TTSDataset(Dataset):
             mel = mel.transpose(0, 2, 1)

             # convert things to pytorch
-            text_lenghts = torch.LongTensor(text_lenghts)
+            text_lengths = torch.LongTensor(text_lengths)
             text = torch.LongTensor(text)
             mel = torch.FloatTensor(mel).contiguous()
             mel_lengths = torch.LongTensor(mel_lengths)
@@ -453,7 +467,7 @@ class TTSDataset(Dataset):
             # compute linear spectrogram
             if self.compute_linear_spec:
-                linear = [self.ap.spectrogram(w).astype("float32") for w in wav]
+                linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
                 linear = prepare_tensor(linear, self.outputs_per_step)
                 linear = linear.transpose(0, 2, 1)
                 assert mel.shape[1] == linear.shape[1]
@@ -464,11 +478,11 @@ class TTSDataset(Dataset):
             # format waveforms
             wav_padded = None
             if self.return_wav:
-                wav_lengths = [w.shape[0] for w in wav]
+                wav_lengths = [w.shape[0] for w in batch["wav"]]
                 max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
                 wav_lengths = torch.LongTensor(wav_lengths)
-                wav_padded = torch.zeros(len(batch), 1, max_wav_len)
-                for i, w in enumerate(wav):
+                wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len)
+                for i, w in enumerate(batch["wav"]):
                     mel_length = mel_lengths_adjusted[i]
                     w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
                     w = w[: mel_length * self.ap.hop_length]
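The padding above keeps waveform and spectrogram lengths consistent: each wav is edge-padded by one hop per output step and then cut to exactly mel_length * hop_length samples. A small numeric check with assumed values:

import numpy as np

hop_length, outputs_per_step = 256, 1  # assumed values
mel_length = 40
w = np.random.randn(mel_length * hop_length - 100)  # waveform slightly short
w = np.pad(w, (0, hop_length * outputs_per_step), mode="edge")
w = w[: mel_length * hop_length]
assert w.shape[0] == mel_length * hop_length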
@@ -477,18 +491,16 @@ class TTSDataset(Dataset):
             # compute f0
             # TODO: compare perf in collate_fn vs in load_data
+            pitch = None
             if self.compute_f0:
-                pitch = [b["pitch"] for b in batch]
-                pitch = prepare_data(pitch)
+                pitch = prepare_data(batch["pitch"])
                 assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
                 pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()  # B x 1 xT
-            else:
-                pitch = None

             # collate attention alignments
-            if batch[0]["attn"] is not None:
-                attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing]
+            if batch["attn"][0] is not None:
+                attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing]
                 for idx, attn in enumerate(attns):
                     pad2 = mel.shape[1] - attn.shape[1]
                     pad1 = text.shape[1] - attn.shape[0]
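For reference, pad1 and pad2 grow each transposed alignment to the batch-wide text and spectrogram lengths. The same computation in isolation, with made-up sizes:

import numpy as np

attn = np.ones((5, 40))  # (text_len, mel_len) for one sample, already transposed
pad1 = 7 - attn.shape[0]   # 7: longest text in the batch (assumed)
pad2 = 50 - attn.shape[1]  # 50: longest spectrogram in the batch (assumed)
attn = np.pad(attn, ((0, pad1), (0, pad2)))
assert attn.shape == (7, 50)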
@@ -502,18 +514,18 @@ class TTSDataset(Dataset):
             # TODO: return dictionary
             return {
                 "text": text,
-                "text_lengths": text_lenghts,
-                "speaker_names": speaker_names,
+                "text_lengths": text_lengths,
+                "speaker_names": batch["speaker_name"],
                 "linear": linear,
                 "mel": mel,
                 "mel_lengths": mel_lengths,
                 "stop_targets": stop_targets,
-                "item_idxs": item_idxs,
+                "item_idxs": batch["item_idx"],
                 "d_vectors": d_vectors,
                 "speaker_ids": speaker_ids,
                 "attns": attns,
                 "waveform": wav_padded,
-                "raw_text": raw_text,
+                "raw_text": batch["raw_text"],
                 "pitch": pitch,
             }
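With collate_fn returning a dict, downstream code reads batch fields by name instead of position. A hedged usage sketch, where dataset stands for a TTSDataset instance constructed elsewhere:

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=32, collate_fn=dataset.collate_fn)
for batch in loader:
    text, text_lengths = batch["text"], batch["text_lengths"]
    mel, mel_lengths = batch["mel"], batch["mel_lengths"]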
@@ -567,13 +579,20 @@ class PitchExtractor:
     def normalize_pitch(self, pitch):
         zero_idxs = np.where(pitch == 0.0)[0]
-        pitch -= self.mean
-        pitch /= self.std
+        pitch = pitch - self.mean
+        pitch = pitch / self.std
         pitch[zero_idxs] = 0.0
         return pitch

+    def denormalize_pitch(self, pitch):
+        zero_idxs = np.where(pitch == 0.0)[0]
+        pitch *= self.std
+        pitch += self.mean
+        pitch[zero_idxs] = 0.0
+        return pitch
+
     @staticmethod
-    def _load_or_compute_pitch(ap, wav_file, cache_path):
+    def load_or_compute_pitch(ap, wav_file, cache_path):
         """
         compute pitch and return a numpy array of pitch values
         """
@@ -582,7 +601,7 @@ class PitchExtractor:
             pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
         else:
             pitch = np.load(pitch_file)
-        return pitch
+        return pitch.astype(np.float32)

     @staticmethod
     def _pitch_worker(args):
@@ -596,7 +615,7 @@ class PitchExtractor:
             return pitch
         return None

-    def compute_pitch(self, cache_path, num_workers=0):
+    def compute_pitch(self, ap, cache_path, num_workers=0):
         """Compute the input sequences with multi-processing.
         Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
         if not os.path.exists(cache_path):
@@ -607,12 +626,12 @@ class PitchExtractor:
         if num_workers == 0:
             pitch_vecs = []
             for _, item in enumerate(tqdm.tqdm(self.items)):
-                pitch_vecs += [self._pitch_worker([item, self.ap, cache_path])]
+                pitch_vecs += [self._pitch_worker([item, ap, cache_path])]
         else:
             with Pool(num_workers) as p:
                 pitch_vecs = list(
                     tqdm.tqdm(
-                        p.imap(PitchExtractor._pitch_worker, [[item, self.ap, cache_path] for item in self.items]),
+                        p.imap(PitchExtractor._pitch_worker, [[item, ap, cache_path] for item in self.items]),
                         total=len(self.items),
                     )
                 )
@@ -623,5 +642,5 @@ class PitchExtractor:
     def load_pitch_stats(self, cache_path):
         stats_path = os.path.join(cache_path, "pitch_stats.npy")
         stats = np.load(stats_path, allow_pickle=True).item()
-        self.mean = stats["mean"]
-        self.std = stats["std"]
+        self.mean = stats["mean"].astype(np.float32)
+        self.std = stats["std"].astype(np.float32)