From ae5e8b2b18d50323351737fe5d1d7581b5e97fc4 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Mon, 17 Dec 2018 16:35:52 +0100 Subject: [PATCH] update loader_tests.py --- tests/loader_tests.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/tests/loader_tests.py b/tests/loader_tests.py index c945a592..7fc003a1 100644 --- a/tests/loader_tests.py +++ b/tests/loader_tests.py @@ -23,6 +23,9 @@ if not os.path.exists(c.data_path_cache): if not os.path.exists(c.data_path): DATA_EXIST = False +print(" > Dynamic data loader test: {}".format(DATA_EXIST)) +print(" > Cache data loader test: {}".format(CACHE_EXIST)) + class TestTTSDataset(unittest.TestCase): def __init__(self, *args, **kwargs): super(TestTTSDataset, self).__init__(*args, **kwargs) @@ -199,7 +202,7 @@ class TestTTSDatasetCached(unittest.TestCase): def _create_dataloader(self, batch_size, r, bgs): - dataset = TTSDatasetCached.MyDataset( + dataset = TTSDataset.MyDataset( c.data_path_cache, 'tts_metadata.csv', r, @@ -207,7 +210,9 @@ class TestTTSDatasetCached(unittest.TestCase): preprocessor=tts_cache, ap=self.ap, batch_group_size=bgs, - min_seq_len=c.min_seq_len) + min_seq_len=c.min_seq_len, + max_seq_len=c.max_seq_len, + cached=True) dataloader = DataLoader( dataset, @@ -299,11 +304,17 @@ class TestTTSDatasetCached(unittest.TestCase): abs(mel.T).astype("float32") - abs(mel_dl[:-1])).sum() # check mel-spec correctness - mel_spec = mel_input[0].cpu().numpy() + mel_spec = mel_input[-1].cpu().numpy() wav = self.ap.inv_mel_spectrogram(mel_spec.T) self.ap.save_wav(wav, OUTPATH + '/mel_inv_dataloader_cache.wav') - shutil.copy(item_idx[0], OUTPATH + '/mel_target_dataloader_cache.wav') + shutil.copy(item_idx[-1], OUTPATH + '/mel_target_dataloader_cache.wav') + + # check linear-spec + linear_spec = linear_input[-1].cpu().numpy() + wav = self.ap.inv_spectrogram(linear_spec.T) + self.ap.save_wav(wav, OUTPATH + '/linear_inv_dataloader_cache.wav') + shutil.copy(item_idx[-1], OUTPATH + '/linear_target_dataloader_cache.wav') # check the last time step to be zero padded assert mel_input[0, -1].sum() == 0