From fd287aa4389d4a3e14f01e924914f2a3d4bc0208 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 6 Sep 2021 14:29:45 +0000 Subject: [PATCH] Update loader tests for dict return --- tests/data_tests/test_loader.py | 66 ++++++++++++++++----------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py index 717b2e0f..0fbb6bde 100644 --- a/tests/data_tests/test_loader.py +++ b/tests/data_tests/test_loader.py @@ -68,15 +68,15 @@ class TestTTSDataset(unittest.TestCase): for i, data in enumerate(dataloader): if i == self.max_loader_iter: break - text_input = data[0] - text_lengths = data[1] - speaker_name = data[2] - linear_input = data[3] - mel_input = data[4] - mel_lengths = data[5] - stop_target = data[6] - item_idx = data[7] - wavs = data[11] + text_input = data['text'] + text_lengths = data['text_lengths'] + speaker_name = data['speaker_names'] + linear_input = data['linear'] + mel_input = data['mel'] + mel_lengths = data['mel_lengths'] + stop_target = data['stop_targets'] + item_idx = data['item_idxs'] + wavs = data['waveform'] neg_values = text_input[text_input < 0] check_count = len(neg_values) @@ -113,14 +113,14 @@ class TestTTSDataset(unittest.TestCase): for i, data in enumerate(dataloader): if i == self.max_loader_iter: break - text_input = data[0] - text_lengths = data[1] - speaker_name = data[2] - linear_input = data[3] - mel_input = data[4] - mel_lengths = data[5] - stop_target = data[6] - item_idx = data[7] + text_input = data['text'] + text_lengths = data['text_lengths'] + speaker_name = data['speaker_names'] + linear_input = data['linear'] + mel_input = data['mel'] + mel_lengths = data['mel_lengths'] + stop_target = data['stop_targets'] + item_idx = data['item_idxs'] avg_length = mel_lengths.numpy().mean() assert avg_length >= last_length @@ -139,14 +139,14 @@ class TestTTSDataset(unittest.TestCase): for i, data in enumerate(dataloader): if i == self.max_loader_iter: break - text_input = data[0] - text_lengths = data[1] - speaker_name = data[2] - linear_input = data[3] - mel_input = data[4] - mel_lengths = data[5] - stop_target = data[6] - item_idx = data[7] + text_input = data['text'] + text_lengths = data['text_lengths'] + speaker_name = data['speaker_names'] + linear_input = data['linear'] + mel_input = data['mel'] + mel_lengths = data['mel_lengths'] + stop_target = data['stop_targets'] + item_idx = data['item_idxs'] # check mel_spec consistency wav = np.asarray(self.ap.load_wav(item_idx[0]), dtype=np.float32) @@ -188,14 +188,14 @@ class TestTTSDataset(unittest.TestCase): for i, data in enumerate(dataloader): if i == self.max_loader_iter: break - text_input = data[0] - text_lengths = data[1] - speaker_name = data[2] - linear_input = data[3] - mel_input = data[4] - mel_lengths = data[5] - stop_target = data[6] - item_idx = data[7] + text_input = data['text'] + text_lengths = data['text_lengths'] + speaker_name = data['speaker_names'] + linear_input = data['linear'] + mel_input = data['mel'] + mel_lengths = data['mel_lengths'] + stop_target = data['stop_targets'] + item_idx = data['item_idxs'] if mel_lengths[0] > mel_lengths[1]: idx = 0