mirror of https://github.com/coqui-ai/TTS.git
Update loader tests for dict return
This commit is contained in:
parent
2c4bbbf9b9
commit
fd287aa438
|
@ -68,15 +68,15 @@ class TestTTSDataset(unittest.TestCase):
|
||||||
for i, data in enumerate(dataloader):
|
for i, data in enumerate(dataloader):
|
||||||
if i == self.max_loader_iter:
|
if i == self.max_loader_iter:
|
||||||
break
|
break
|
||||||
text_input = data[0]
|
text_input = data['text']
|
||||||
text_lengths = data[1]
|
text_lengths = data['text_lengths']
|
||||||
speaker_name = data[2]
|
speaker_name = data['speaker_names']
|
||||||
linear_input = data[3]
|
linear_input = data['linear']
|
||||||
mel_input = data[4]
|
mel_input = data['mel']
|
||||||
mel_lengths = data[5]
|
mel_lengths = data['mel_lengths']
|
||||||
stop_target = data[6]
|
stop_target = data['stop_targets']
|
||||||
item_idx = data[7]
|
item_idx = data['item_idxs']
|
||||||
wavs = data[11]
|
wavs = data['waveform']
|
||||||
|
|
||||||
neg_values = text_input[text_input < 0]
|
neg_values = text_input[text_input < 0]
|
||||||
check_count = len(neg_values)
|
check_count = len(neg_values)
|
||||||
|
@ -113,14 +113,14 @@ class TestTTSDataset(unittest.TestCase):
|
||||||
for i, data in enumerate(dataloader):
|
for i, data in enumerate(dataloader):
|
||||||
if i == self.max_loader_iter:
|
if i == self.max_loader_iter:
|
||||||
break
|
break
|
||||||
text_input = data[0]
|
text_input = data['text']
|
||||||
text_lengths = data[1]
|
text_lengths = data['text_lengths']
|
||||||
speaker_name = data[2]
|
speaker_name = data['speaker_names']
|
||||||
linear_input = data[3]
|
linear_input = data['linear']
|
||||||
mel_input = data[4]
|
mel_input = data['mel']
|
||||||
mel_lengths = data[5]
|
mel_lengths = data['mel_lengths']
|
||||||
stop_target = data[6]
|
stop_target = data['stop_targets']
|
||||||
item_idx = data[7]
|
item_idx = data['item_idxs']
|
||||||
|
|
||||||
avg_length = mel_lengths.numpy().mean()
|
avg_length = mel_lengths.numpy().mean()
|
||||||
assert avg_length >= last_length
|
assert avg_length >= last_length
|
||||||
|
@ -139,14 +139,14 @@ class TestTTSDataset(unittest.TestCase):
|
||||||
for i, data in enumerate(dataloader):
|
for i, data in enumerate(dataloader):
|
||||||
if i == self.max_loader_iter:
|
if i == self.max_loader_iter:
|
||||||
break
|
break
|
||||||
text_input = data[0]
|
text_input = data['text']
|
||||||
text_lengths = data[1]
|
text_lengths = data['text_lengths']
|
||||||
speaker_name = data[2]
|
speaker_name = data['speaker_names']
|
||||||
linear_input = data[3]
|
linear_input = data['linear']
|
||||||
mel_input = data[4]
|
mel_input = data['mel']
|
||||||
mel_lengths = data[5]
|
mel_lengths = data['mel_lengths']
|
||||||
stop_target = data[6]
|
stop_target = data['stop_targets']
|
||||||
item_idx = data[7]
|
item_idx = data['item_idxs']
|
||||||
|
|
||||||
# check mel_spec consistency
|
# check mel_spec consistency
|
||||||
wav = np.asarray(self.ap.load_wav(item_idx[0]), dtype=np.float32)
|
wav = np.asarray(self.ap.load_wav(item_idx[0]), dtype=np.float32)
|
||||||
|
@ -188,14 +188,14 @@ class TestTTSDataset(unittest.TestCase):
|
||||||
for i, data in enumerate(dataloader):
|
for i, data in enumerate(dataloader):
|
||||||
if i == self.max_loader_iter:
|
if i == self.max_loader_iter:
|
||||||
break
|
break
|
||||||
text_input = data[0]
|
text_input = data['text']
|
||||||
text_lengths = data[1]
|
text_lengths = data['text_lengths']
|
||||||
speaker_name = data[2]
|
speaker_name = data['speaker_names']
|
||||||
linear_input = data[3]
|
linear_input = data['linear']
|
||||||
mel_input = data[4]
|
mel_input = data['mel']
|
||||||
mel_lengths = data[5]
|
mel_lengths = data['mel_lengths']
|
||||||
stop_target = data[6]
|
stop_target = data['stop_targets']
|
||||||
item_idx = data[7]
|
item_idx = data['item_idxs']
|
||||||
|
|
||||||
if mel_lengths[0] > mel_lengths[1]:
|
if mel_lengths[0] > mel_lengths[1]:
|
||||||
idx = 0
|
idx = 0
|
||||||
|
|
Loading…
Reference in New Issue