Update loader tests for dict return

This commit is contained in:
Eren Gölge 2021-09-06 14:29:45 +00:00
parent 2c4bbbf9b9
commit fd287aa438
1 changed files with 33 additions and 33 deletions

View File

@ -68,15 +68,15 @@ class TestTTSDataset(unittest.TestCase):
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data[0]
text_lengths = data[1]
speaker_name = data[2]
linear_input = data[3]
mel_input = data[4]
mel_lengths = data[5]
stop_target = data[6]
item_idx = data[7]
wavs = data[11]
text_input = data['text']
text_lengths = data['text_lengths']
speaker_name = data['speaker_names']
linear_input = data['linear']
mel_input = data['mel']
mel_lengths = data['mel_lengths']
stop_target = data['stop_targets']
item_idx = data['item_idxs']
wavs = data['waveform']
neg_values = text_input[text_input < 0]
check_count = len(neg_values)
@ -113,14 +113,14 @@ class TestTTSDataset(unittest.TestCase):
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data[0]
text_lengths = data[1]
speaker_name = data[2]
linear_input = data[3]
mel_input = data[4]
mel_lengths = data[5]
stop_target = data[6]
item_idx = data[7]
text_input = data['text']
text_lengths = data['text_lengths']
speaker_name = data['speaker_names']
linear_input = data['linear']
mel_input = data['mel']
mel_lengths = data['mel_lengths']
stop_target = data['stop_targets']
item_idx = data['item_idxs']
avg_length = mel_lengths.numpy().mean()
assert avg_length >= last_length
@ -139,14 +139,14 @@ class TestTTSDataset(unittest.TestCase):
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data[0]
text_lengths = data[1]
speaker_name = data[2]
linear_input = data[3]
mel_input = data[4]
mel_lengths = data[5]
stop_target = data[6]
item_idx = data[7]
text_input = data['text']
text_lengths = data['text_lengths']
speaker_name = data['speaker_names']
linear_input = data['linear']
mel_input = data['mel']
mel_lengths = data['mel_lengths']
stop_target = data['stop_targets']
item_idx = data['item_idxs']
# check mel_spec consistency
wav = np.asarray(self.ap.load_wav(item_idx[0]), dtype=np.float32)
@ -188,14 +188,14 @@ class TestTTSDataset(unittest.TestCase):
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data[0]
text_lengths = data[1]
speaker_name = data[2]
linear_input = data[3]
mel_input = data[4]
mel_lengths = data[5]
stop_target = data[6]
item_idx = data[7]
text_input = data['text']
text_lengths = data['text_lengths']
speaker_name = data['speaker_names']
linear_input = data['linear']
mel_input = data['mel']
mel_lengths = data['mel_lengths']
stop_target = data['stop_targets']
item_idx = data['item_idxs']
if mel_lengths[0] > mel_lengths[1]:
idx = 0