Refactor TTSDataset

Return a dict from `collate_fn`
Refactor batch handling in `collate_fn`
A couple of bug fixes
Eren Gölge 2021-09-06 14:24:06 +00:00
parent 6878ddfbea
commit 91a70e80b2
4 changed files with 74 additions and 51 deletions

View File

@@ -77,14 +77,14 @@ def set_filename(wav_path, out_path):
 def format_data(data):
     # setup input data
-    text_input = data[0]
-    text_lengths = data[1]
-    mel_input = data[4]
-    mel_lengths = data[5]
-    item_idx = data[7]
-    d_vectors = data[8]
-    speaker_ids = data[9]
-    attn_mask = data[10]
+    text_input = data['text']
+    text_lengths = data['text_lengths']
+    mel_input = data['mel']
+    mel_lengths = data['mel_lengths']
+    item_idx = data['item_idxs']
+    d_vectors = data['d_vectors']
+    speaker_ids = data['speaker_ids']
+    attn_mask = data['attns']
     avg_text_length = torch.mean(text_lengths.float())
     avg_spec_length = torch.mean(mel_lengths.float())
@@ -132,9 +132,8 @@ def inference(
         speaker_c = speaker_ids
     elif d_vectors is not None:
         speaker_c = d_vectors
     outputs = model.inference_with_MAS(
-        text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c}
+        text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids}
     )
     model_output = outputs["model_outputs"]
     model_output = model_output.transpose(1, 2).detach().cpu().numpy()
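With `collate_fn` now returning a dict, downstream code such as `format_data` reads batches by key instead of by position, so adding or reordering collate fields no longer silently shifts every index. A minimal sketch of the new access pattern (key names taken from the diff above; `data` is one batch yielded by the loader):

def format_data(data):
    # keyed access replaces the old brittle positional indexing (data[0], data[4], ...)
    text_input = data["text"]
    text_lengths = data["text_lengths"]
    mel_input = data["mel"]
    mel_lengths = data["mel_lengths"]
    item_idx = data["item_idxs"]
    d_vectors = data["d_vectors"]
    speaker_ids = data["speaker_ids"]
    attn_mask = data["attns"]
    return text_input, text_lengths, mel_input, mel_lengths, item_idx, d_vectors, speaker_ids, attn_mask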

View File

@@ -271,7 +271,12 @@ class Trainer:
         # setup scheduler
         self.scheduler = self.get_scheduler(self.model, self.config, self.optimizer)
-        if self.args.continue_path:
-            self.scheduler.last_epoch = self.restore_step
+        if self.scheduler is not None:
+            if self.args.continue_path:
+                if isinstance(self.scheduler, list):
+                    for scheduler in self.scheduler:
+                        scheduler.last_epoch = self.restore_step
+                else:
+                    self.scheduler.last_epoch = self.restore_step

         # DISTRUBUTED
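The same restore logic in isolation: when training continues from a checkpoint, each scheduler's `last_epoch` is fast-forwarded to the restored step so the learning-rate curve resumes where it left off instead of restarting. A toy sketch (the model and optimizer are stand-ins, and `restore_step` plays the role of `Trainer.restore_step`):

import torch

model = torch.nn.Linear(4, 4)  # stand-in model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

restore_step = 25  # step count read back from the checkpoint
if scheduler is not None:
    schedulers = scheduler if isinstance(scheduler, list) else [scheduler]
    for s in schedulers:
        s.last_epoch = restore_step  # resume the schedule mid-curve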

View File

@@ -142,7 +142,7 @@ class BaseTTSConfig(BaseTrainingConfig):
             enable / disable masking loss values against padded segments of samples in a batch.
         sort_by_audio_len (bool):
-            If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `True`.
+            If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `False`.
         min_seq_len (int):
             Minimum sequence length to be used at training.
@@ -201,7 +201,7 @@ class BaseTTSConfig(BaseTrainingConfig):
     batch_group_size: int = 0
     loss_masking: bool = None
     # dataloading
-    sort_by_audio_len: bool = True
+    sort_by_audio_len: bool = False
     min_seq_len: int = 1
     max_seq_len: int = float("inf")
     compute_f0: bool = False
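With the default flipped to `False`, sorting by audio length is now opt-in. A hypothetical usage snippet (assuming `BaseTTSConfig` is importable from `TTS.tts.configs.shared_configs`):

from TTS.tts.configs.shared_configs import BaseTTSConfig

# opt back in to the old behaviour of sorting batches by audio length
config = BaseTTSConfig(sort_by_audio_len=True)
assert config.sort_by_audio_len is True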

View File

@@ -9,7 +9,7 @@ import torch
 import tqdm
 from torch.utils.data import Dataset
-from TTS.tts.utils.data import _pad_data, prepare_data, prepare_stop_target, prepare_tensor
+from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
 from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
 from TTS.utils.audio import AudioProcessor
@@ -249,8 +249,8 @@ class TTSDataset(Dataset):
         pitch = None
         if self.compute_f0:
-            pitch = self.pitch_extractor._load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
-            pitch = self.pitch_extractor.normalize_pitch(pitch)
+            pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
+            pitch = self.pitch_extractor.normalize_pitch(pitch.astype(np.float32))
         sample = {
             "raw_text": raw_text,
@@ -356,6 +361,11 @@ class TTSDataset(Dataset):
             temp_items = new_items[offset:end_offset]
             random.shuffle(temp_items)
             new_items[offset:end_offset] = temp_items
+
+        if len(new_items) == 0:
+            raise RuntimeError(" [!] No items left after filtering.")
+
+        # update items to the new sorted items
         self.items = new_items
         # logging
@@ -376,6 +381,18 @@ class TTSDataset(Dataset):
     def __getitem__(self, idx):
         return self.load_data(idx)

+    @staticmethod
+    def _sort_batch(batch, text_lengths):
+        """Sort the batch by the input text length for RNN efficiency.
+
+        Args:
+            batch (Dict): Batch returned by `__getitem__`.
+            text_lengths (List[int]): Lengths of the input character sequences.
+        """
+        text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True)
+        batch = [batch[idx] for idx in ids_sorted_decreasing]
+        return batch, text_lengths, ids_sorted_decreasing
+
     def collate_fn(self, batch):
         r"""
             Perform preprocessing and create a final data batch:
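The new `_sort_batch` helper in isolation, on a toy batch of three samples:

import torch

batch = [{"text": "hello world"}, {"text": "hi"}, {"text": "hey there"}]
text_lengths = [len(d["text"]) for d in batch]

lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True)
batch = [batch[idx] for idx in ids_sorted_decreasing]
# batch is now longest-first: "hello world", "hey there", "hi"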
@@ -388,30 +405,27 @@ class TTSDataset(Dataset):
         # Puts each data field into a tensor with outer dimension batch size
         if isinstance(batch[0], collections.abc.Mapping):
-            text_lenghts = np.array([len(d["text"]) for d in batch])
+            text_lengths = np.array([len(d["text"]) for d in batch])

             # sort items with text input length for RNN efficiency
-            text_lenghts, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lenghts), dim=0, descending=True)
-            wav = [batch[idx]["wav"] for idx in ids_sorted_decreasing]
-            item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing]
-            text = [batch[idx]["text"] for idx in ids_sorted_decreasing]
-            raw_text = [batch[idx]["raw_text"] for idx in ids_sorted_decreasing]
-            speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
+            batch, text_lengths, ids_sorted_decreasing = self._sort_batch(batch, text_lengths)
+
+            # convert list of dicts to dict of lists
+            batch = {k: [dic[k] for dic in batch] for k in batch[0]}

             # get pre-computed d-vectors
             if self.d_vector_mapping is not None:
-                wav_files_names = [batch[idx]["wav_file_name"] for idx in ids_sorted_decreasing]
+                wav_files_names = [batch["wav_file_name"][idx] for idx in ids_sorted_decreasing]
                 d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
             else:
                 d_vectors = None
             # get numerical speaker ids from speaker names
             if self.speaker_id_mapping:
-                speaker_ids = [self.speaker_id_mapping[sn] for sn in speaker_names]
+                speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]]
             else:
                 speaker_ids = None
             # compute features
-            mel = [self.ap.melspectrogram(w).astype("float32") for w in wav]
+            mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]]

             mel_lengths = [m.shape[1] for m in mel]
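The pivot of the whole refactor is the one-liner that turns the list of per-sample dicts into a dict of parallel lists, so every later step can address a field as `batch[key]`. On toy data:

samples = [
    {"text": "a", "wav": [0.1]},        # one dict per __getitem__ call
    {"text": "bb", "wav": [0.2, 0.3]},
]
batch = {k: [d[k] for d in samples] for k in samples[0]}
# {'text': ['a', 'bb'], 'wav': [[0.1], [0.2, 0.3]]}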
@@ -430,7 +444,7 @@ class TTSDataset(Dataset):
             stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)

             # PAD sequences with longest instance in the batch
-            text = prepare_data(text).astype(np.int32)
+            text = prepare_data(batch["text"]).astype(np.int32)

             # PAD features with longest instance
             mel = prepare_tensor(mel, self.outputs_per_step)
@@ -439,7 +453,7 @@ class TTSDataset(Dataset):
             mel = mel.transpose(0, 2, 1)

             # convert things to pytorch
-            text_lenghts = torch.LongTensor(text_lenghts)
+            text_lengths = torch.LongTensor(text_lengths)
             text = torch.LongTensor(text)
             mel = torch.FloatTensor(mel).contiguous()
             mel_lengths = torch.LongTensor(mel_lengths)
@@ -453,7 +467,7 @@ class TTSDataset(Dataset):
             # compute linear spectrogram
             if self.compute_linear_spec:
-                linear = [self.ap.spectrogram(w).astype("float32") for w in wav]
+                linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
                 linear = prepare_tensor(linear, self.outputs_per_step)
                 linear = linear.transpose(0, 2, 1)
                 assert mel.shape[1] == linear.shape[1]
@@ -464,11 +478,11 @@ class TTSDataset(Dataset):
             # format waveforms
             wav_padded = None
             if self.return_wav:
-                wav_lengths = [w.shape[0] for w in wav]
+                wav_lengths = [w.shape[0] for w in batch["wav"]]
                 max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
                 wav_lengths = torch.LongTensor(wav_lengths)
-                wav_padded = torch.zeros(len(batch), 1, max_wav_len)
-                for i, w in enumerate(wav):
+                wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len)
+                for i, w in enumerate(batch["wav"]):
                     mel_length = mel_lengths_adjusted[i]
                     w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
                     w = w[: mel_length * self.ap.hop_length]
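The waveform-padding branch as a standalone sketch: each wav is edge-padded then trimmed so its length matches its adjusted mel length times the hop size, and all wavs land in one zero-padded tensor (toy `hop_length` and random wavs stand in for values the real code takes from `AudioProcessor`):

import numpy as np
import torch

hop_length, outputs_per_step = 256, 1
wavs = [np.random.randn(5000).astype("float32"), np.random.randn(7000).astype("float32")]
mel_lengths_adjusted = [len(w) // hop_length for w in wavs]

max_wav_len = max(mel_lengths_adjusted) * hop_length
wav_padded = torch.zeros(len(wavs), 1, max_wav_len)
for i, w in enumerate(wavs):
    mel_length = mel_lengths_adjusted[i]
    w = np.pad(w, (0, hop_length * outputs_per_step), mode="edge")  # guard the tail
    w = w[: mel_length * hop_length]                                # trim to mel-aligned length
    wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)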
@@ -477,18 +491,16 @@ class TTSDataset(Dataset):
             # compute f0
             # TODO: compare perf in collate_fn vs in load_data
-            pitch = None
             if self.compute_f0:
-                pitch = [b["pitch"] for b in batch]
-                pitch = prepare_data(pitch)
+                pitch = prepare_data(batch["pitch"])
                 assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
                 pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()  # B x 1 xT
             else:
                 pitch = None

             # collate attention alignments
-            if batch[0]["attn"] is not None:
-                attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing]
+            if batch["attn"][0] is not None:
+                attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing]
                 for idx, attn in enumerate(attns):
                     pad2 = mel.shape[1] - attn.shape[1]
                     pad1 = text.shape[1] - attn.shape[0]
@@ -502,18 +514,18 @@ class TTSDataset(Dataset):
             # TODO: return dictionary
             return {
                 "text": text,
-                "text_lengths": text_lenghts,
-                "speaker_names": speaker_names,
+                "text_lengths": text_lengths,
+                "speaker_names": batch["speaker_name"],
                 "linear": linear,
                 "mel": mel,
                 "mel_lengths": mel_lengths,
                 "stop_targets": stop_targets,
-                "item_idxs": item_idxs,
+                "item_idxs": batch["item_idx"],
                 "d_vectors": d_vectors,
                 "speaker_ids": speaker_ids,
                 "attns": attns,
                 "waveform": wav_padded,
-                "raw_text": raw_text,
+                "raw_text": batch["raw_text"],
                 "pitch": pitch,
             }
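End to end, the refactored `collate_fn` plugs into a standard `DataLoader`, and consumers read named fields from each batch. A hypothetical wiring, assuming `dataset` is an already constructed `TTSDataset`:

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=dataset.collate_fn)
for batch in loader:
    text, text_lengths = batch["text"], batch["text_lengths"]
    mel, mel_lengths = batch["mel"], batch["mel_lengths"]
    # ... feed the model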
@@ -567,13 +579,20 @@ class PitchExtractor:
     def normalize_pitch(self, pitch):
         zero_idxs = np.where(pitch == 0.0)[0]
-        pitch -= self.mean
-        pitch /= self.std
+        pitch = pitch - self.mean
+        pitch = pitch / self.std
+        pitch[zero_idxs] = 0.0
+        return pitch
+
+    def denormalize_pitch(self, pitch):
+        zero_idxs = np.where(pitch == 0.0)[0]
+        pitch *= self.std
+        pitch += self.mean
         pitch[zero_idxs] = 0.0
         return pitch

     @staticmethod
-    def _load_or_compute_pitch(ap, wav_file, cache_path):
+    def load_or_compute_pitch(ap, wav_file, cache_path):
         """
         compute pitch and return a numpy array of pitch values
         """
@@ -582,7 +601,7 @@ class PitchExtractor:
             pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
         else:
             pitch = np.load(pitch_file)
-        return pitch
+        return pitch.astype(np.float32)

     @staticmethod
     def _pitch_worker(args):
@@ -596,7 +615,7 @@ class PitchExtractor:
             return pitch
         return None

-    def compute_pitch(self, cache_path, num_workers=0):
+    def compute_pitch(self, ap, cache_path, num_workers=0):
         """Compute the input sequences with multi-processing.
         Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
         if not os.path.exists(cache_path):
@@ -607,12 +626,12 @@ class PitchExtractor:
         if num_workers == 0:
             pitch_vecs = []
             for _, item in enumerate(tqdm.tqdm(self.items)):
-                pitch_vecs += [self._pitch_worker([item, self.ap, cache_path])]
+                pitch_vecs += [self._pitch_worker([item, ap, cache_path])]
         else:
             with Pool(num_workers) as p:
                 pitch_vecs = list(
                     tqdm.tqdm(
-                        p.imap(PitchExtractor._pitch_worker, [[item, self.ap, cache_path] for item in self.items]),
+                        p.imap(PitchExtractor._pitch_worker, [[item, ap, cache_path] for item in self.items]),
                         total=len(self.items),
                     )
                 )
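The caching loop reduced to its skeleton: one picklable worker dispatched either serially or through a `Pool`, the same shape `compute_pitch` uses (the toy worker below stands in for per-file F0 extraction and caching):

from multiprocessing import Pool

import tqdm

def pitch_worker(args):
    item, cache_path = args
    return len(item)  # stand-in for: compute F0, np.save into cache_path

if __name__ == "__main__":
    items = ["a.wav", "bb.wav", "ccc.wav"]
    args = [(item, "/tmp/f0_cache") for item in items]

    num_workers = 2
    if num_workers == 0:
        pitch_vecs = [pitch_worker(a) for a in tqdm.tqdm(args)]
    else:
        with Pool(num_workers) as p:
            pitch_vecs = list(tqdm.tqdm(p.imap(pitch_worker, args), total=len(args)))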
@@ -623,5 +642,5 @@ class PitchExtractor:
     def load_pitch_stats(self, cache_path):
         stats_path = os.path.join(cache_path, "pitch_stats.npy")
         stats = np.load(stats_path, allow_pickle=True).item()
-        self.mean = stats["mean"]
-        self.std = stats["std"]
+        self.mean = stats["mean"].astype(np.float32)
+        self.std = stats["std"].astype(np.float32)