diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py
index 337dcfa5..d4a12c07 100644
--- a/TTS/tts/datasets/dataset.py
+++ b/TTS/tts/datasets/dataset.py
@@ -38,7 +38,7 @@ class TTSDataset(Dataset):
         outputs_per_step: int,
         compute_linear_spec: bool,
         ap: AudioProcessor,
-        meta_data: List[Dict],
+        samples: List[Dict],
         tokenizer: "TTSTokenizer" = None,
         compute_f0: bool = False,
         f0_cache_path: str = None,
@@ -67,7 +67,7 @@ class TTSDataset(Dataset):
 
             ap (TTS.tts.utils.AudioProcessor): Audio processor object.
 
-            meta_data (list): List of dataset samples.
+            samples (list): List of dataset samples.
 
-            tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init
-                internally else use the given. Defaults to None.
+            tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None, init
+                internally; else use the given one. Defaults to None.
@@ -111,7 +111,7 @@
         """
         super().__init__()
         self.batch_group_size = batch_group_size
-        self._samples = meta_data
+        self._samples = samples
         self.outputs_per_step = outputs_per_step
         self.sample_rate = ap.sample_rate
         self.compute_linear_spec = compute_linear_spec
@@ -200,7 +200,7 @@
             token_ids = self.get_phonemes(idx, text)["token_ids"]
         else:
             token_ids = self.tokenizer.text_to_ids(text)
-        return token_ids
+        return np.array(token_ids, dtype=np.int32)
 
     def load_data(self, idx):
         item = self.samples[idx]
@@ -258,7 +258,7 @@
         return audio_lengths, text_lengths
 
     @staticmethod
-    def sort_and_filter_by_length(lengths: List[int], min_len: int, max_len: int):
+    def filter_by_length(lengths: List[int], min_len: int, max_len: int):
         idxs = np.argsort(lengths)  # ascending order
         ignore_idx = []
         keep_idx = []
@@ -270,6 +270,11 @@
             keep_idx.append(idx)
         return ignore_idx, keep_idx
 
+    @staticmethod
+    def sort_by_length(lengths: List[int]):
+        idxs = np.argsort(lengths)  # ascending order
+        return idxs
+
     @staticmethod
     def create_buckets(samples, batch_group_size: int):
         for i in range(len(samples) // batch_group_size):
@@ -280,24 +285,33 @@
             samples[offset:end_offset] = temp_items
         return samples
 
+    def select_samples_by_idx(self, idxs):
+        samples = []
+        audio_lengths = []
+        text_lengths = []
+        for idx in idxs:
+            samples.append(self.samples[idx])
+            audio_lengths.append(self.audio_lengths[idx])
+            text_lengths.append(self.text_lengths[idx])
+        return samples, audio_lengths, text_lengths
+
     def preprocess_samples(self):
-        r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
-        range.
+        r"""Sort `samples` based on text length or audio length in ascending order. Filter out samples outside the
+        length range.
         """
         # sort items based on the sequence length in ascending order
-        text_ignore_idx, text_keep_idx = self.sort_and_filter_by_length(
-            self.text_lengths, self.min_text_len, self.max_text_len
-        )
-        audio_ignore_idx, audio_keep_idx = self.sort_and_filter_by_length(
+        text_ignore_idx, text_keep_idx = self.filter_by_length(self.text_lengths, self.min_text_len, self.max_text_len)
+        audio_ignore_idx, audio_keep_idx = self.filter_by_length(
             self.audio_lengths, self.min_audio_len, self.max_audio_len
         )
-        keep_idx = list(set(audio_keep_idx) | set(text_keep_idx))
+        # a sample must pass BOTH filters to be kept, so intersect the keep sets
+        keep_idx = list(set(audio_keep_idx) & set(text_keep_idx))
         ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx))
 
-        samples = []
-        for idx in keep_idx:
-            samples.append(self.samples[idx])
+        samples, audio_lengths, _ = self.select_samples_by_idx(keep_idx)
+
+        # `sorted_idxs` indexes the *filtered* lists, so map it back through
+        # `keep_idx` before selecting from the full sample list again
+        sorted_idxs = self.sort_by_length(audio_lengths)
+        samples, audio_lengths, text_lengths = self.select_samples_by_idx([keep_idx[i] for i in sorted_idxs])
 
         if len(samples) == 0:
             raise RuntimeError(" [!] No samples left")
No samples left") @@ -309,6 +323,8 @@ class TTSDataset(Dataset): # update items to the new sorted items self.samples = samples + self.audio_lengths = audio_lengths + self.text_lengths = text_lengtsh if self.verbose: print(" | > Preprocessing samples") @@ -391,7 +407,7 @@ class TTSDataset(Dataset): stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step) # PAD sequences with longest instance in the batch - text = prepare_data(batch["token_ids"]).astype(np.int32) + token_ids = prepare_data(batch["token_ids"]).astype(np.int32) # PAD features with longest instance mel = prepare_tensor(mel, self.outputs_per_step) @@ -401,7 +417,7 @@ class TTSDataset(Dataset): # convert things to pytorch token_ids_lengths = torch.LongTensor(token_ids_lengths) - text = torch.LongTensor(text) + token_ids = torch.LongTensor(token_ids) mel = torch.FloatTensor(mel).contiguous() mel_lengths = torch.LongTensor(mel_lengths) stop_targets = torch.FloatTensor(stop_targets) @@ -453,7 +469,7 @@ class TTSDataset(Dataset): attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing] for idx, attn in enumerate(attns): pad2 = mel.shape[1] - attn.shape[1] - pad1 = text.shape[1] - attn.shape[0] + pad1 = token_ids.shape[1] - attn.shape[0] assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}" attn = np.pad(attn, [[0, pad1], [0, pad2]]) attns[idx] = attn @@ -461,7 +477,7 @@ class TTSDataset(Dataset): attns = torch.FloatTensor(attns).unsqueeze(1) return { - "token_id": text, + "token_id": token_ids, "token_id_lengths": token_ids_lengths, "speaker_names": batch["speaker_name"], "linear": linear, @@ -786,7 +802,7 @@ if __name__ == "__main__": dataset = TTSDataset( outputs_per_step=1, compute_linear_spec=False, - meta_data=samples, + samples=samples, ap=ap, return_wav=False, batch_group_size=0, diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py index fac430f0..68a1c575 100644 --- a/TTS/tts/utils/text/tokenizer.py +++ b/TTS/tts/utils/text/tokenizer.py @@ -147,6 +147,7 @@ class TTSTokenizer: if isinstance(config.text_cleaner, (str, list)): text_cleaner = getattr(cleaners, config.text_cleaner) + phonemizer = None if config.use_phonemes: # init phoneme set characters = IPAPhonemes().init_from_config(config) diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py index d210995d..712e59e3 100644 --- a/tests/data_tests/test_loader.py +++ b/tests/data_tests/test_loader.py @@ -7,9 +7,9 @@ import torch from torch.utils.data import DataLoader from tests import get_tests_output_path -from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.configs.shared_configs import BaseTTSConfig, BaseDatasetConfig from TTS.tts.datasets import TTSDataset, load_tts_samples -from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor # pylint: disable=unused-variable @@ -50,18 +50,19 @@ class TestTTSDataset(unittest.TestCase): meta_data_train, meta_data_eval = load_tts_samples(dataset_config, eval_split=True, eval_split_size=0.2) items = meta_data_train + meta_data_eval + tokenizer = TTSTokenizer.init_from_config(c) dataset = TTSDataset( - r, - c.text_cleaner, + outputs_per_step=r, compute_linear_spec=True, return_wav=True, + tokenizer=tokenizer, ap=self.ap, - meta_data=items, - characters=c.characters, + samples=items, batch_group_size=bgs, - min_seq_len=c.min_seq_len, - max_seq_len=float("inf"), - use_phonemes=False, + min_text_len=c.min_text_len, + 
+            max_text_len=c.max_text_len,
+            min_audio_len=c.min_audio_len,
+            max_audio_len=c.max_audio_len,
         )
         dataloader = DataLoader(
             dataset,
@@ -80,27 +81,26 @@
             for i, data in enumerate(dataloader):
                 if i == self.max_loader_iter:
                     break
-                text_input = data["text"]
-                text_lengths = data["text_lengths"]
+                text_input = data["token_id"]
+                _ = data["token_id_lengths"]
                 speaker_name = data["speaker_names"]
                 linear_input = data["linear"]
                 mel_input = data["mel"]
                 mel_lengths = data["mel_lengths"]
-                stop_target = data["stop_targets"]
-                item_idx = data["item_idxs"]
+                _ = data["stop_targets"]
+                _ = data["item_idxs"]
                 wavs = data["waveform"]
                 neg_values = text_input[text_input < 0]
                 check_count = len(neg_values)
-                assert check_count == 0, " !! Negative values in text_input: {}".format(check_count)
-                assert isinstance(speaker_name[0], str)
-                assert linear_input.shape[0] == c.batch_size
-                assert linear_input.shape[2] == self.ap.fft_size // 2 + 1
-                assert mel_input.shape[0] == c.batch_size
-                assert mel_input.shape[2] == c.audio["num_mels"]
-                assert (
-                    wavs.shape[1] == mel_input.shape[1] * c.audio.hop_length
-                ), f"wavs.shape: {wavs.shape[1]}, mel_input.shape: {mel_input.shape[1] * c.audio.hop_length}"
+
+                # check basic conditions
+                self.assertEqual(check_count, 0)
+                self.assertEqual(linear_input.shape[0], mel_input.shape[0])
+                self.assertEqual(mel_input.shape[0], c.batch_size)
+                self.assertEqual(linear_input.shape[2], self.ap.fft_size // 2 + 1)
+                self.assertEqual(mel_input.shape[2], c.audio["num_mels"])
+                self.assertEqual(wavs.shape[1], mel_input.shape[1] * c.audio.hop_length)
+                self.assertIsInstance(speaker_name[0], str)
 
-                # make sure that the computed mels and the waveform match and correctly computed
+                # make sure that the computed mels and the waveform match and are correctly computed
                 mel_new = self.ap.melspectrogram(wavs[0].squeeze().numpy())
@@ -109,55 +109,58 @@
                 # guarantee that both mel-spectrograms have the same size and that we will remove waveform padding
                 mel_new = mel_new[:, :mel_lengths[0]]
                 ignore_seg = -(1 + c.audio.win_length // c.audio.hop_length)
-                mel_diff = (mel_new - mel_dataloader)[:, 0:ignore_seg]
-                assert abs(mel_diff.sum()) < 1e-5
+                mel_diff = (mel_new[:, : mel_input.shape[1]] - mel_input[0].T.numpy())[:, 0:ignore_seg]
+                self.assertLess(abs(mel_diff.sum()), 1e-5)
 
                 # check normalization ranges
                 if self.ap.symmetric_norm:
-                    assert mel_input.max() <= self.ap.max_norm
-                    assert mel_input.min() >= -self.ap.max_norm  # pylint: disable=invalid-unary-operand-type
-                    assert mel_input.min() < 0
+                    self.assertLessEqual(mel_input.max(), self.ap.max_norm)
+                    self.assertGreaterEqual(
+                        mel_input.min(), -self.ap.max_norm  # pylint: disable=invalid-unary-operand-type
+                    )
+                    self.assertLess(mel_input.min(), 0)
                 else:
-                    assert mel_input.max() <= self.ap.max_norm
-                    assert mel_input.min() >= 0
+                    self.assertLessEqual(mel_input.max(), self.ap.max_norm)
+                    self.assertGreaterEqual(mel_input.min(), 0)
 
     def test_batch_group_shuffle(self):
         if ok_ljspeech:
             dataloader, dataset = self._create_dataloader(2, c.r, 16)
             last_length = 0
-            frames = dataset.items
+            frames = dataset.samples
             for i, data in enumerate(dataloader):
                 if i == self.max_loader_iter:
                     break
-                text_input = data["text"]
-                text_lengths = data["text_lengths"]
-                speaker_name = data["speaker_names"]
-                linear_input = data["linear"]
-                mel_input = data["mel"]
                 mel_lengths = data["mel_lengths"]
-                stop_target = data["stop_targets"]
-                item_idx = data["item_idxs"]
-
                 avg_length = mel_lengths.numpy().mean()
-                assert avg_length >= last_length
-            dataloader.dataset.sort_and_filter_items()
+            dataloader.dataset.preprocess_samples()
             is_items_reordered = False
-            for idx, item in enumerate(dataloader.dataset.items):
+            for idx, item in enumerate(dataloader.dataset.samples):
                 if item != frames[idx]:
                     is_items_reordered = True
                     break
-            assert is_items_reordered
+            self.assertGreaterEqual(avg_length, last_length)
+            self.assertTrue(is_items_reordered)
+
+    def test_padding_and_spectrograms(self):
+        def check_conditions(idx, linear_input, mel_input, stop_target, mel_lengths):
+            self.assertNotEqual(linear_input[idx, -1].sum(), 0)  # check padding
+            self.assertNotEqual(linear_input[idx, -2].sum(), 0)
+            self.assertNotEqual(mel_input[idx, -1].sum(), 0)
+            self.assertNotEqual(mel_input[idx, -2].sum(), 0)
+            self.assertEqual(stop_target[idx, -1], 1)
+            self.assertEqual(stop_target[idx, -2], 0)
+            self.assertEqual(stop_target[idx].sum(), 1)
+            self.assertEqual(len(mel_lengths.shape), 1)
+            self.assertEqual(mel_lengths[idx], linear_input[idx].shape[0])
+            self.assertEqual(mel_lengths[idx], mel_input[idx].shape[0])
 
-    def test_padding_and_spec(self):
         if ok_ljspeech:
-            dataloader, dataset = self._create_dataloader(1, 1, 0)
+            dataloader, _ = self._create_dataloader(1, 1, 0)
 
             for i, data in enumerate(dataloader):
                 if i == self.max_loader_iter:
                     break
-                text_input = data["text"]
-                text_lengths = data["text_lengths"]
-                speaker_name = data["speaker_names"]
                 linear_input = data["linear"]
                 mel_input = data["mel"]
                 mel_lengths = data["mel_lengths"]
@@ -172,7 +175,7 @@
-                # NOTE: Below needs to check == 0 but due to an unknown reason
-                # there is a slight difference between two matrices.
-                # TODO: Check this assert cond more in detail.
-                assert abs(mel.T - mel_dl).max() < 1e-5, abs(mel.T - mel_dl).max()
+                # NOTE: the check below should be == 0, but for an unknown reason
+                # there is a slight difference between the two matrices.
+                # TODO: investigate this assert condition in more detail.
+                self.assertLess(abs(mel.T - mel_dl).max(), 1e-5)
 
                 # check mel-spec correctness
                 mel_spec = mel_input[0].cpu().numpy()
@@ -186,56 +189,36 @@
                 self.ap.save_wav(wav, OUTPATH + "/linear_inv_dataloader.wav")
                 shutil.copy(item_idx[0], OUTPATH + "/linear_target_dataloader.wav")
 
-                # check the last time step to be zero padded
-                assert linear_input[0, -1].sum() != 0
-                assert linear_input[0, -2].sum() != 0
-                assert mel_input[0, -1].sum() != 0
-                assert mel_input[0, -2].sum() != 0
-                assert stop_target[0, -1] == 1
-                assert stop_target[0, -2] == 0
-                assert stop_target.sum() == 1
-                assert len(mel_lengths.shape) == 1
-                assert mel_lengths[0] == linear_input[0].shape[0]
-                assert mel_lengths[0] == mel_input[0].shape[0]
+                # check the outputs
+                check_conditions(0, linear_input, mel_input, stop_target, mel_lengths)
 
             # Test for batch size 2
-            dataloader, dataset = self._create_dataloader(2, 1, 0)
+            dataloader, _ = self._create_dataloader(2, 1, 0)
 
             for i, data in enumerate(dataloader):
                 if i == self.max_loader_iter:
                     break
-                text_input = data["text"]
-                text_lengths = data["text_lengths"]
-                speaker_name = data["speaker_names"]
                 linear_input = data["linear"]
                 mel_input = data["mel"]
                 mel_lengths = data["mel_lengths"]
                 stop_target = data["stop_targets"]
                 item_idx = data["item_idxs"]
 
+                # pick the index of the longest sequence in the batch
                 if mel_lengths[0] > mel_lengths[1]:
                     idx = 0
                 else:
                     idx = 1
 
-                # check the first item in the batch
-                assert linear_input[idx, -1].sum() != 0
-                assert linear_input[idx, -2].sum() != 0, linear_input
-                assert mel_input[idx, -1].sum() != 0
-                assert mel_input[idx, -2].sum() != 0, mel_input
-                assert stop_target[idx, -1] == 1
-                assert stop_target[idx, -2] == 0
-                assert stop_target[idx].sum() == 1
-                assert len(mel_lengths.shape) == 1
-                assert mel_lengths[idx] == mel_input[idx].shape[0]
-                assert mel_lengths[idx] == linear_input[idx].shape[0]
+                # check the longer item in the batch
+                check_conditions(idx, linear_input, mel_input, stop_target, mel_lengths)
 
-                # check the second itme in the batch
-                assert linear_input[1 - idx, -1].sum() == 0
-                assert mel_input[1 - idx, -1].sum() == 0
-                assert stop_target[1, mel_lengths[1] - 1] == 1
-                assert stop_target[1, mel_lengths[1] :].sum() == stop_target.shape[1] - mel_lengths[1]
-                assert len(mel_lengths.shape) == 1
+                # check the other (shorter) item in the batch
+                self.assertEqual(linear_input[1 - idx, -1].sum(), 0)
+                self.assertEqual(mel_input[1 - idx, -1].sum(), 0)
+                self.assertEqual(stop_target[1 - idx, mel_lengths[1 - idx] - 1], 1)
+                self.assertEqual(stop_target[1 - idx, mel_lengths[1 - idx] :].sum(), stop_target.shape[1] - mel_lengths[1 - idx])
+                self.assertEqual(len(mel_lengths.shape), 1)
 
                 # check batch zero-frame conditions (zero-frame disabled)
                 # assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
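
Reviewer note, not part of the patch: the refactored preprocess_samples() now composes three small helpers - filter_by_length, sort_by_length, and select_samples_by_idx. The standalone Python sketch below replays that filter-then-sort flow on toy data; the helper bodies mirror the patched static methods, while the sample names and length values are made up for illustration.

from typing import List

import numpy as np


def filter_by_length(lengths: List[int], min_len: int, max_len: int):
    # scan indices in ascending-length order and split them into
    # out-of-range (ignore) and in-range (keep) buckets
    idxs = np.argsort(lengths)
    ignore_idx, keep_idx = [], []
    for idx in idxs:
        if lengths[idx] < min_len or lengths[idx] > max_len:
            ignore_idx.append(idx)
        else:
            keep_idx.append(idx)
    return ignore_idx, keep_idx


def sort_by_length(lengths: List[int]):
    return np.argsort(lengths)  # ascending order


samples = ["short", "tiny", "long", "medium"]
audio_lengths = [400, 100, 900, 250]

# drop samples whose audio length falls outside [150, 1000]
_, keep_idx = filter_by_length(audio_lengths, 150, 1000)
kept = [samples[i] for i in keep_idx]
kept_lengths = [audio_lengths[i] for i in keep_idx]

# sort what is left by audio length; the returned positions index the
# *filtered* lists, which is why preprocess_samples() maps them back
# through keep_idx before selecting from self.samples again
order = sort_by_length(kept_lengths)
print([kept[i] for i in order])          # ['medium', 'short', 'long']
print([kept_lengths[i] for i in order])  # [250, 400, 900]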
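
Reviewer note, not part of the patch: a minimal usage sketch of the constructor after the meta_data -> samples rename, mirroring the updated test. It assumes an already-loaded BaseTTSConfig instance `c` and a BaseDatasetConfig instance `dataset_config`; the AudioProcessor(**c.audio) construction is an assumption, and only keyword arguments that appear in this patch are used.

from TTS.tts.datasets import TTSDataset, load_tts_samples
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor

# collect samples and build the pieces TTSDataset now expects up front
meta_data_train, meta_data_eval = load_tts_samples(dataset_config, eval_split=True, eval_split_size=0.2)
samples = meta_data_train + meta_data_eval
tokenizer = TTSTokenizer.init_from_config(c)
ap = AudioProcessor(**c.audio)  # assumption: `c.audio` carries the audio settings

dataset = TTSDataset(
    outputs_per_step=1,
    compute_linear_spec=True,
    return_wav=True,
    tokenizer=tokenizer,  # replaces the old text_cleaner/characters/use_phonemes arguments
    ap=ap,
    samples=samples,  # renamed from `meta_data`
    batch_group_size=0,
    min_text_len=c.min_text_len,
    max_text_len=c.max_text_len,
    min_audio_len=c.min_audio_len,
    max_audio_len=c.max_audio_len,
)
dataset.preprocess_samples()  # filter by length, then sort by audio length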