From 91a70e80b20da7666e28ca14098e255afa32aa5d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Mon, 6 Sep 2021 14:24:06 +0000
Subject: [PATCH] Refactor TTSDataset

Make `collate_fn` return a dict instead of a tuple
Refactor batch handling in `collate_fn`
A couple of bug fixes

---
 TTS/bin/extract_tts_spectrograms.py | 19 +++---
 TTS/trainer.py                      |  9 ++-
 TTS/tts/configs/shared_configs.py   |  4 +-
 TTS/tts/datasets/TTSDataset.py      | 93 +++++++++++++++++------
 4 files changed, 74 insertions(+), 51 deletions(-)

diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py
index 6ec99fac..9f54cb39 100755
--- a/TTS/bin/extract_tts_spectrograms.py
+++ b/TTS/bin/extract_tts_spectrograms.py
@@ -77,14 +77,14 @@ def set_filename(wav_path, out_path):
 def format_data(data):
     # setup input data
-    text_input = data[0]
-    text_lengths = data[1]
-    mel_input = data[4]
-    mel_lengths = data[5]
-    item_idx = data[7]
-    d_vectors = data[8]
-    speaker_ids = data[9]
-    attn_mask = data[10]
+    text_input = data['text']
+    text_lengths = data['text_lengths']
+    mel_input = data['mel']
+    mel_lengths = data['mel_lengths']
+    item_idx = data['item_idxs']
+    d_vectors = data['d_vectors']
+    speaker_ids = data['speaker_ids']
+    attn_mask = data['attns']
 
     avg_text_length = torch.mean(text_lengths.float())
     avg_spec_length = torch.mean(mel_lengths.float())
@@ -132,9 +132,8 @@ def inference(
             speaker_c = speaker_ids
         elif d_vectors is not None:
             speaker_c = d_vectors
-
         outputs = model.inference_with_MAS(
-            text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c}
+            text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids}
         )
         model_output = outputs["model_outputs"]
         model_output = model_output.transpose(1, 2).detach().cpu().numpy()
diff --git a/TTS/trainer.py b/TTS/trainer.py
index 9bb5b096..bc9a49c6 100644
--- a/TTS/trainer.py
+++ b/TTS/trainer.py
@@ -271,8 +271,13 @@ class Trainer:
         # setup scheduler
         self.scheduler = self.get_scheduler(self.model, self.config, self.optimizer)
 
-        if self.args.continue_path:
-            self.scheduler.last_epoch = self.restore_step
+        if self.scheduler is not None:
+            if self.args.continue_path:
+                if isinstance(self.scheduler, list):
+                    for scheduler in self.scheduler:
+                        scheduler.last_epoch = self.restore_step
+                else:
+                    self.scheduler.last_epoch = self.restore_step
 
         # DISTRUBUTED
         if self.num_gpus > 1:
diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py
index 3dc70786..e208c16c 100644
--- a/TTS/tts/configs/shared_configs.py
+++ b/TTS/tts/configs/shared_configs.py
@@ -142,7 +142,7 @@ class BaseTTSConfig(BaseTrainingConfig):
         enable / disable masking loss values against padded segments of samples in a batch.
 
     sort_by_audio_len (bool):
-        If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `True`.
+        If true, the dataloader sorts the data by audio length; otherwise it sorts by the input text length. Defaults to `False`.
 
     min_seq_len (int):
         Minimum sequence length to be used at training.
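Reviewer note on the `sort_by_audio_len` docstring above and the default flipped in the next hunk: whichever key is sorted on, length-sorting matters because it cuts the padding allocated inside each batch. A throwaway illustration with made-up lengths (not repo code):

    lengths = [120, 480, 90, 300, 60, 450]  # utterance lengths in frames (made up)

    def padding_waste(batches):
        # cells allocated minus cells actually used, summed over all batches
        return sum(max(b) * len(b) - sum(b) for b in batches)

    unsorted_batches = [lengths[i : i + 2] for i in range(0, len(lengths), 2)]
    sorted_batches = [sorted(lengths)[i : i + 2] for i in range(0, len(lengths), 2)]
    print(padding_waste(unsorted_batches), padding_waste(sorted_batches))  # 960 vs 240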
@@ -201,7 +201,7 @@ class BaseTTSConfig(BaseTrainingConfig):
     batch_group_size: int = 0
     loss_masking: bool = None
     # dataloading
-    sort_by_audio_len: bool = True
+    sort_by_audio_len: bool = False
     min_seq_len: int = 1
     max_seq_len: int = float("inf")
     compute_f0: bool = False
diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py
index 74cb8de1..c81e0e6c 100644
--- a/TTS/tts/datasets/TTSDataset.py
+++ b/TTS/tts/datasets/TTSDataset.py
@@ -9,7 +9,7 @@ import torch
 import tqdm
 from torch.utils.data import Dataset
 
-from TTS.tts.utils.data import _pad_data, prepare_data, prepare_stop_target, prepare_tensor
+from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
 from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
 from TTS.utils.audio import AudioProcessor
 
@@ -249,8 +249,8 @@ class TTSDataset(Dataset):
 
         pitch = None
         if self.compute_f0:
-            pitch = self.pitch_extractor._load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
-            pitch = self.pitch_extractor.normalize_pitch(pitch)
+            pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
+            pitch = self.pitch_extractor.normalize_pitch(pitch.astype(np.float32))
 
         sample = {
             "raw_text": raw_text,
@@ -356,6 +356,11 @@ class TTSDataset(Dataset):
             temp_items = new_items[offset:end_offset]
             random.shuffle(temp_items)
             new_items[offset:end_offset] = temp_items
+
+        if len(new_items) == 0:
+            raise RuntimeError(" [!] No items left after filtering.")
+
+        # update items to the new sorted items
         self.items = new_items
 
         # logging
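Reviewer note: the `len(new_items) == 0` guard added above turns a silent mis-configuration into an immediate, readable error. A standalone toy of the failure mode it catches (item layout and limits are invented; the real filter uses `min_seq_len`/`max_seq_len` from the config):

    items = [("wav1", 50), ("wav2", 70), ("wav3", 80)]  # (file, length) pairs, made up
    min_seq_len, max_seq_len = 100, 200

    new_items = [item for item in items if min_seq_len <= item[1] <= max_seq_len]
    if len(new_items) == 0:
        # previously this surfaced much later as an opaque empty-DataLoader error
        raise RuntimeError(" [!] No items left after filtering.")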
@@ -376,6 +381,18 @@ class TTSDataset(Dataset):
 
     def __getitem__(self, idx):
         return self.load_data(idx)
 
+    @staticmethod
+    def _sort_batch(batch, text_lengths):
+        """Sort the batch by the input text length for RNN efficiency.
+
+        Args:
+            batch (List[Dict]): Batch returned by `__getitem__`.
+            text_lengths (List[int]): Lengths of the input character sequences.
+        """
+        text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True)
+        batch = [batch[idx] for idx in ids_sorted_decreasing]
+        return batch, text_lengths, ids_sorted_decreasing
+
     def collate_fn(self, batch):
         r"""
             Perform preprocessing and create a final data batch:
@@ -388,30 +405,27 @@ class TTSDataset(Dataset):
 
         # Puts each data field into a tensor with outer dimension batch size
         if isinstance(batch[0], collections.abc.Mapping):
-            text_lenghts = np.array([len(d["text"]) for d in batch])
+            text_lengths = np.array([len(d["text"]) for d in batch])
 
             # sort items with text input length for RNN efficiency
-            text_lenghts, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lenghts), dim=0, descending=True)
+            batch, text_lengths, ids_sorted_decreasing = self._sort_batch(batch, text_lengths)
 
-            wav = [batch[idx]["wav"] for idx in ids_sorted_decreasing]
-            item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing]
-            text = [batch[idx]["text"] for idx in ids_sorted_decreasing]
-            raw_text = [batch[idx]["raw_text"] for idx in ids_sorted_decreasing]
+            # convert list of dicts to dict of lists
+            batch = {k: [dic[k] for dic in batch] for k in batch[0]}
 
-            speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
             # get pre-computed d-vectors
             if self.d_vector_mapping is not None:
-                wav_files_names = [batch[idx]["wav_file_name"] for idx in ids_sorted_decreasing]
+                wav_files_names = batch["wav_file_name"]  # batch is already sorted by `_sort_batch`
                 d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
             else:
                 d_vectors = None
             # get numerical speaker ids from speaker names
             if self.speaker_id_mapping:
-                speaker_ids = [self.speaker_id_mapping[sn] for sn in speaker_names]
+                speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]]
             else:
                 speaker_ids = None
             # compute features
-            mel = [self.ap.melspectrogram(w).astype("float32") for w in wav]
+            mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]]
 
             mel_lengths = [m.shape[1] for m in mel]
@@ -430,7 +444,7 @@ class TTSDataset(Dataset):
             stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)
 
             # PAD sequences with longest instance in the batch
-            text = prepare_data(text).astype(np.int32)
+            text = prepare_data(batch["text"]).astype(np.int32)
 
             # PAD features with longest instance
             mel = prepare_tensor(mel, self.outputs_per_step)
@@ -439,7 +453,7 @@ class TTSDataset(Dataset):
             mel = mel.transpose(0, 2, 1)
 
             # convert things to pytorch
-            text_lenghts = torch.LongTensor(text_lenghts)
+            text_lengths = torch.LongTensor(text_lengths)
             text = torch.LongTensor(text)
             mel = torch.FloatTensor(mel).contiguous()
             mel_lengths = torch.LongTensor(mel_lengths)
@@ -453,7 +467,7 @@ class TTSDataset(Dataset):
 
             # compute linear spectrogram
             if self.compute_linear_spec:
-                linear = [self.ap.spectrogram(w).astype("float32") for w in wav]
+                linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
                 linear = prepare_tensor(linear, self.outputs_per_step)
                 linear = linear.transpose(0, 2, 1)
                 assert mel.shape[1] == linear.shape[1]
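Reviewer note: the `_sort_batch` + dict-of-lists flow above is also why the d-vector lookup no longer re-applies `ids_sorted_decreasing` (changed in this revision): the batch is already in sorted order, and indexing it with the permutation a second time would scramble it. A self-contained sketch of the invariant:

    import torch

    batch = [{"text": [1] * n, "item_idx": f"utt{n}"} for n in (2, 5, 3)]
    text_lengths = [len(d["text"]) for d in batch]

    text_lengths, ids = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True)
    batch = [batch[i] for i in ids]                       # sorted: utt5, utt3, utt2
    batch = {k: [d[k] for d in batch] for k in batch[0]}  # dict of lists, same order

    assert batch["item_idx"] == ["utt5", "utt3", "utt2"]
    assert [batch["item_idx"][i] for i in ids] != batch["item_idx"]  # double-indexing scrambles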
@@ -464,11 +478,11 @@ class TTSDataset(Dataset):
             # format waveforms
             wav_padded = None
             if self.return_wav:
-                wav_lengths = [w.shape[0] for w in wav]
+                wav_lengths = [w.shape[0] for w in batch["wav"]]
                 max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
                 wav_lengths = torch.LongTensor(wav_lengths)
-                wav_padded = torch.zeros(len(batch), 1, max_wav_len)
-                for i, w in enumerate(wav):
+                wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len)
+                for i, w in enumerate(batch["wav"]):
                     mel_length = mel_lengths_adjusted[i]
                     w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
                     w = w[: mel_length * self.ap.hop_length]
@@ -477,18 +491,16 @@ class TTSDataset(Dataset):
 
             # compute f0
             # TODO: compare perf in collate_fn vs in load_data
-            pitch = None
             if self.compute_f0:
-                pitch = [b["pitch"] for b in batch]
-                pitch = prepare_data(pitch)
+                pitch = prepare_data(batch["pitch"])
                 assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
                 pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()  # B x 1 xT
             else:
                 pitch = None
 
             # collate attention alignments
-            if batch[0]["attn"] is not None:
-                attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing]
+            if batch["attn"][0] is not None:
+                attns = [attn.T for attn in batch["attn"]]  # batch is already sorted by `_sort_batch`
                 for idx, attn in enumerate(attns):
                     pad2 = mel.shape[1] - attn.shape[1]
                     pad1 = text.shape[1] - attn.shape[0]
@@ -502,18 +514,18 @@ class TTSDataset(Dataset):
 
             # TODO: return dictionary
             return {
                 "text": text,
-                "text_lengths": text_lenghts,
-                "speaker_names": speaker_names,
+                "text_lengths": text_lengths,
+                "speaker_names": batch["speaker_name"],
                 "linear": linear,
                 "mel": mel,
                 "mel_lengths": mel_lengths,
                 "stop_targets": stop_targets,
-                "item_idxs": item_idxs,
+                "item_idxs": batch["item_idx"],
                 "d_vectors": d_vectors,
                 "speaker_ids": speaker_ids,
                 "attns": attns,
                 "waveform": wav_padded,
-                "raw_text": raw_text,
+                "raw_text": batch["raw_text"],
                 "pitch": pitch,
             }
@@ -567,13 +579,20 @@ class PitchExtractor:
 
     def normalize_pitch(self, pitch):
         zero_idxs = np.where(pitch == 0.0)[0]
-        pitch -= self.mean
-        pitch /= self.std
+        pitch = pitch - self.mean
+        pitch = pitch / self.std
+        pitch[zero_idxs] = 0.0
+        return pitch
+
+    def denormalize_pitch(self, pitch):
+        zero_idxs = np.where(pitch == 0.0)[0]
+        pitch = pitch * self.std
+        pitch = pitch + self.mean
         pitch[zero_idxs] = 0.0
         return pitch
 
     @staticmethod
-    def _load_or_compute_pitch(ap, wav_file, cache_path):
+    def load_or_compute_pitch(ap, wav_file, cache_path):
         """
         compute pitch and return a numpy array of pitch values
         """
@@ -582,7 +601,7 @@ class PitchExtractor:
             pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
         else:
             pitch = np.load(pitch_file)
-        return pitch
+        return pitch.astype(np.float32)
 
     @staticmethod
     def _pitch_worker(args):
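Reviewer note on the `normalize_pitch`/`denormalize_pitch` pair above: the two should round-trip, and unvoiced frames (pitch == 0) must stay exactly zero through both directions. A standalone check with invented statistics; it shares the real code's caveat that a voiced frame sitting exactly on the mean would collide with the zero flag:

    import numpy as np

    mean, std = np.float32(140.0), np.float32(40.0)
    pitch = np.array([0.0, 110.0, 220.0, 0.0, 180.0], dtype=np.float32)

    norm = (pitch - mean) / std  # out-of-place, as in the patched method
    norm[pitch == 0.0] = 0.0

    denorm = norm * std + mean
    denorm[norm == 0.0] = 0.0

    assert np.allclose(denorm, pitch)  # round-trip holds
    assert denorm.dtype == np.float32  # float32 preserved end to end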
@@ -596,7 +615,7 @@ class PitchExtractor:
             return pitch
         return None
 
-    def compute_pitch(self, cache_path, num_workers=0):
+    def compute_pitch(self, ap, cache_path, num_workers=0):
         """Compute the input sequences with multi-processing.
         Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
         if not os.path.exists(cache_path):
             os.makedirs(cache_path, exist_ok=True)
@@ -607,12 +626,12 @@ class PitchExtractor:
         if num_workers == 0:
             pitch_vecs = []
             for _, item in enumerate(tqdm.tqdm(self.items)):
-                pitch_vecs += [self._pitch_worker([item, self.ap, cache_path])]
+                pitch_vecs += [self._pitch_worker([item, ap, cache_path])]
         else:
             with Pool(num_workers) as p:
                 pitch_vecs = list(
                     tqdm.tqdm(
-                        p.imap(PitchExtractor._pitch_worker, [[item, self.ap, cache_path] for item in self.items]),
+                        p.imap(PitchExtractor._pitch_worker, [[item, ap, cache_path] for item in self.items]),
                         total=len(self.items),
                     )
                 )
@@ -623,5 +642,5 @@ class PitchExtractor:
     def load_pitch_stats(self, cache_path):
         stats_path = os.path.join(cache_path, "pitch_stats.npy")
         stats = np.load(stats_path, allow_pickle=True).item()
-        self.mean = stats["mean"]
-        self.std = stats["std"]
+        self.mean = stats["mean"].astype(np.float32)
+        self.std = stats["std"].astype(np.float32)
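Reviewer note: passing `ap` into `compute_pitch` (instead of reading `self.ap`) keeps the worker a plain callable whose inputs all arrive through `args`, which is what `Pool.imap` needs for pickling. A standalone sketch of the same caching pattern with toy pitch values; only the `pitch_stats.npy` layout mirrors the real `load_pitch_stats`, everything else is invented:

    import numpy as np
    from multiprocessing import Pool

    def _toy_pitch_worker(args):
        item, scale = args  # stand-in for (item, ap, cache_path)
        return np.arange(1, 4, dtype=np.float32) * scale * (item + 1)

    if __name__ == "__main__":
        items = list(range(4))
        with Pool(2) as p:  # mirrors the num_workers > 0 branch
            pitch_vecs = p.map(_toy_pitch_worker, [(item, 2.0) for item in items])
        pitch = np.concatenate(pitch_vecs)
        stats = {"mean": pitch.mean().astype(np.float32), "std": pitch.std().astype(np.float32)}
        np.save("pitch_stats.npy", stats)  # read back via np.load(..., allow_pickle=True).item()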