mirror of https://github.com/coqui-ai/TTS.git
Update loader tests for dict return
This commit is contained in:
parent
2c4bbbf9b9
commit
fd287aa438
|
@ -68,15 +68,15 @@ class TestTTSDataset(unittest.TestCase):
|
|||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
speaker_name = data[2]
|
||||
linear_input = data[3]
|
||||
mel_input = data[4]
|
||||
mel_lengths = data[5]
|
||||
stop_target = data[6]
|
||||
item_idx = data[7]
|
||||
wavs = data[11]
|
||||
text_input = data['text']
|
||||
text_lengths = data['text_lengths']
|
||||
speaker_name = data['speaker_names']
|
||||
linear_input = data['linear']
|
||||
mel_input = data['mel']
|
||||
mel_lengths = data['mel_lengths']
|
||||
stop_target = data['stop_targets']
|
||||
item_idx = data['item_idxs']
|
||||
wavs = data['waveform']
|
||||
|
||||
neg_values = text_input[text_input < 0]
|
||||
check_count = len(neg_values)
|
||||
|
@ -113,14 +113,14 @@ class TestTTSDataset(unittest.TestCase):
|
|||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
speaker_name = data[2]
|
||||
linear_input = data[3]
|
||||
mel_input = data[4]
|
||||
mel_lengths = data[5]
|
||||
stop_target = data[6]
|
||||
item_idx = data[7]
|
||||
text_input = data['text']
|
||||
text_lengths = data['text_lengths']
|
||||
speaker_name = data['speaker_names']
|
||||
linear_input = data['linear']
|
||||
mel_input = data['mel']
|
||||
mel_lengths = data['mel_lengths']
|
||||
stop_target = data['stop_targets']
|
||||
item_idx = data['item_idxs']
|
||||
|
||||
avg_length = mel_lengths.numpy().mean()
|
||||
assert avg_length >= last_length
|
||||
|
@ -139,14 +139,14 @@ class TestTTSDataset(unittest.TestCase):
|
|||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
speaker_name = data[2]
|
||||
linear_input = data[3]
|
||||
mel_input = data[4]
|
||||
mel_lengths = data[5]
|
||||
stop_target = data[6]
|
||||
item_idx = data[7]
|
||||
text_input = data['text']
|
||||
text_lengths = data['text_lengths']
|
||||
speaker_name = data['speaker_names']
|
||||
linear_input = data['linear']
|
||||
mel_input = data['mel']
|
||||
mel_lengths = data['mel_lengths']
|
||||
stop_target = data['stop_targets']
|
||||
item_idx = data['item_idxs']
|
||||
|
||||
# check mel_spec consistency
|
||||
wav = np.asarray(self.ap.load_wav(item_idx[0]), dtype=np.float32)
|
||||
|
@ -188,14 +188,14 @@ class TestTTSDataset(unittest.TestCase):
|
|||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
speaker_name = data[2]
|
||||
linear_input = data[3]
|
||||
mel_input = data[4]
|
||||
mel_lengths = data[5]
|
||||
stop_target = data[6]
|
||||
item_idx = data[7]
|
||||
text_input = data['text']
|
||||
text_lengths = data['text_lengths']
|
||||
speaker_name = data['speaker_names']
|
||||
linear_input = data['linear']
|
||||
mel_input = data['mel']
|
||||
mel_lengths = data['mel_lengths']
|
||||
stop_target = data['stop_targets']
|
||||
item_idx = data['item_idxs']
|
||||
|
||||
if mel_lengths[0] > mel_lengths[1]:
|
||||
idx = 0
|
||||
|
|
Loading…
Reference in New Issue