mirror of https://github.com/coqui-ai/TTS.git
update loader_tests.py
This commit is contained in:
parent
be6e46798b
commit
ae5e8b2b18
|
@ -23,6 +23,9 @@ if not os.path.exists(c.data_path_cache):
|
||||||
if not os.path.exists(c.data_path):
|
if not os.path.exists(c.data_path):
|
||||||
DATA_EXIST = False
|
DATA_EXIST = False
|
||||||
|
|
||||||
|
print(" > Dynamic data loader test: {}".format(DATA_EXIST))
|
||||||
|
print(" > Cache data loader test: {}".format(CACHE_EXIST))
|
||||||
|
|
||||||
class TestTTSDataset(unittest.TestCase):
|
class TestTTSDataset(unittest.TestCase):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(TestTTSDataset, self).__init__(*args, **kwargs)
|
super(TestTTSDataset, self).__init__(*args, **kwargs)
|
||||||
|
@ -199,7 +202,7 @@ class TestTTSDatasetCached(unittest.TestCase):
|
||||||
|
|
||||||
def _create_dataloader(self, batch_size, r, bgs):
|
def _create_dataloader(self, batch_size, r, bgs):
|
||||||
|
|
||||||
dataset = TTSDatasetCached.MyDataset(
|
dataset = TTSDataset.MyDataset(
|
||||||
c.data_path_cache,
|
c.data_path_cache,
|
||||||
'tts_metadata.csv',
|
'tts_metadata.csv',
|
||||||
r,
|
r,
|
||||||
|
@ -207,7 +210,9 @@ class TestTTSDatasetCached(unittest.TestCase):
|
||||||
preprocessor=tts_cache,
|
preprocessor=tts_cache,
|
||||||
ap=self.ap,
|
ap=self.ap,
|
||||||
batch_group_size=bgs,
|
batch_group_size=bgs,
|
||||||
min_seq_len=c.min_seq_len)
|
min_seq_len=c.min_seq_len,
|
||||||
|
max_seq_len=c.max_seq_len,
|
||||||
|
cached=True)
|
||||||
|
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
|
@ -299,11 +304,17 @@ class TestTTSDatasetCached(unittest.TestCase):
|
||||||
abs(mel.T).astype("float32") - abs(mel_dl[:-1])).sum()
|
abs(mel.T).astype("float32") - abs(mel_dl[:-1])).sum()
|
||||||
|
|
||||||
# check mel-spec correctness
|
# check mel-spec correctness
|
||||||
mel_spec = mel_input[0].cpu().numpy()
|
mel_spec = mel_input[-1].cpu().numpy()
|
||||||
wav = self.ap.inv_mel_spectrogram(mel_spec.T)
|
wav = self.ap.inv_mel_spectrogram(mel_spec.T)
|
||||||
self.ap.save_wav(wav,
|
self.ap.save_wav(wav,
|
||||||
OUTPATH + '/mel_inv_dataloader_cache.wav')
|
OUTPATH + '/mel_inv_dataloader_cache.wav')
|
||||||
shutil.copy(item_idx[0], OUTPATH + '/mel_target_dataloader_cache.wav')
|
shutil.copy(item_idx[-1], OUTPATH + '/mel_target_dataloader_cache.wav')
|
||||||
|
|
||||||
|
# check linear-spec
|
||||||
|
linear_spec = linear_input[-1].cpu().numpy()
|
||||||
|
wav = self.ap.inv_spectrogram(linear_spec.T)
|
||||||
|
self.ap.save_wav(wav, OUTPATH + '/linear_inv_dataloader_cache.wav')
|
||||||
|
shutil.copy(item_idx[-1], OUTPATH + '/linear_target_dataloader_cache.wav')
|
||||||
|
|
||||||
# check the last time step to be zero padded
|
# check the last time step to be zero padded
|
||||||
assert mel_input[0, -1].sum() == 0
|
assert mel_input[0, -1].sum() == 0
|
||||||
|
|
Loading…
Reference in New Issue