From dce1715e0f17084a64cb4b7b7e37ab95c3e341df Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Mon, 25 Feb 2019 18:34:06 +0100 Subject: [PATCH] tests updates --- datasets/TTSDataset.py | 2 +- tests/layers_tests.py | 2 +- tests/loader_tests.py | 370 +--------------------------------------- tests/tacotron_tests.py | 4 +- tests/test_config.json | 1 + 5 files changed, 10 insertions(+), 369 deletions(-) diff --git a/datasets/TTSDataset.py b/datasets/TTSDataset.py index 258d00cd..0eb5e115 100644 --- a/datasets/TTSDataset.py +++ b/datasets/TTSDataset.py @@ -61,7 +61,7 @@ class MyDataset(Dataset): self.use_phonemes = use_phonemes self.phoneme_cache_path = phoneme_cache_path self.phoneme_language = phoneme_language - if not os.path.isdir(phoneme_cache_path): + if use_phonemes and not os.path.isdir(phoneme_cache_path): os.makedirs(phoneme_cache_path) print(" > DataLoader initialization") print(" | > Data path: {}".format(root_path)) diff --git a/tests/layers_tests.py b/tests/layers_tests.py index be6a8516..5f769f9c 100644 --- a/tests/layers_tests.py +++ b/tests/layers_tests.py @@ -38,7 +38,7 @@ class CBHGTests(unittest.TestCase): class DecoderTests(unittest.TestCase): def test_in_out(self): - layer = Decoder(in_features=256, memory_dim=80, r=2) + layer = Decoder(in_features=256, memory_dim=80, r=2, memory_size=4, attn_windowing=False) dummy_input = T.rand(4, 8, 256) dummy_memory = T.rand(4, 2, 80) diff --git a/tests/loader_tests.py b/tests/loader_tests.py index 7fc003a1..a70cdfc3 100644 --- a/tests/loader_tests.py +++ b/tests/loader_tests.py @@ -6,7 +6,7 @@ import numpy as np from torch.utils.data import DataLoader from utils.generic_utils import load_config from utils.audio import AudioProcessor -from datasets import TTSDataset, TTSDatasetCached, TTSDatasetMemory +from datasets import TTSDataset from datasets.preprocess import ljspeech, tts_cache file_path = os.path.dirname(os.path.realpath(__file__)) @@ -41,7 +41,9 @@ class TestTTSDataset(unittest.TestCase): preprocessor=ljspeech, ap=self.ap, batch_group_size=bgs, - min_seq_len=c.min_seq_len) + min_seq_len=c.min_seq_len, + max_seq_len=float("inf"), + use_phonemes=False) dataloader = DataLoader( dataset, batch_size=batch_size, @@ -190,366 +192,4 @@ class TestTTSDataset(unittest.TestCase): # check batch conditions assert (linear_input * stop_target.unsqueeze(2)).sum() == 0 - assert (mel_input * stop_target.unsqueeze(2)).sum() == 0 - - -class TestTTSDatasetCached(unittest.TestCase): - def __init__(self, *args, **kwargs): - super(TestTTSDatasetCached, self).__init__(*args, **kwargs) - self.max_loader_iter = 4 - self.c = load_config(os.path.join(c.data_path_cache, 'config.json')) - self.ap = AudioProcessor(**self.c.audio) - - def _create_dataloader(self, batch_size, r, bgs): - - dataset = TTSDataset.MyDataset( - c.data_path_cache, - 'tts_metadata.csv', - r, - c.text_cleaner, - preprocessor=tts_cache, - ap=self.ap, - batch_group_size=bgs, - min_seq_len=c.min_seq_len, - max_seq_len=c.max_seq_len, - cached=True) - - dataloader = DataLoader( - dataset, - batch_size=batch_size, - shuffle=False, - collate_fn=dataset.collate_fn, - drop_last=True, - num_workers=c.num_loader_workers) - return dataloader, dataset - - def test_loader(self): - if ok_ljspeech: - dataloader, dataset = self._create_dataloader(2, c.r, 0) - for i, data in enumerate(dataloader): - if i == self.max_loader_iter: - break - text_input = data[0] - text_lengths = data[1] - linear_input = data[2] - mel_input = data[3] - mel_lengths = data[4] - stop_target = data[5] - item_idx = data[6] - - 
neg_values = text_input[text_input < 0] - check_count = len(neg_values) - assert check_count == 0, \ - " !! Negative values in text_input: {}".format(check_count) - # TODO: more assertion here - assert mel_input.shape[0] == c.batch_size - assert mel_input.shape[2] == c.audio['num_mels'] - - if self.ap.symmetric_norm: - assert mel_input.max() <= self.ap.max_norm - assert mel_input.min() >= -self.ap.max_norm - assert mel_input.min() < 0 - else: - assert mel_input.max() <= self.ap.max_norm - assert mel_input.min() >= 0 - - def test_batch_group_shuffle(self): - if ok_ljspeech: - dataloader, dataset = self._create_dataloader(2, c.r, 16) - frames = dataset.items - for i, data in enumerate(dataloader): - if i == self.max_loader_iter: - break - text_input = data[0] - text_lengths = data[1] - linear_input = data[2] - mel_input = data[3] - mel_lengths = data[4] - stop_target = data[5] - item_idx = data[6] - - neg_values = text_input[text_input < 0] - check_count = len(neg_values) - assert check_count == 0, \ - " !! Negative values in text_input: {}".format(check_count) - # TODO: more assertion here - assert mel_input.shape[0] == c.batch_size - assert mel_input.shape[2] == c.audio['num_mels'] - dataloader.dataset.sort_items() - assert frames[0] != dataloader.dataset.items[0] - - def test_padding_and_spec(self): - if ok_ljspeech: - dataloader, dataset = self._create_dataloader(1, 1, 0) - for i, data in enumerate(dataloader): - if i == self.max_loader_iter: - break - text_input = data[0] - text_lengths = data[1] - linear_input = data[2] - mel_input = data[3] - mel_lengths = data[4] - stop_target = data[5] - item_idx = data[6] - - # check mel_spec consistency - if item_idx[0].split('.')[-1] == 'npy': - wav = np.load(item_idx[0]) - else: - wav = self.ap.load_wav(item_idx[0]) - mel = self.ap.melspectrogram(wav) - mel_dl = mel_input[0].cpu().numpy() - assert (abs(mel.T).astype("float32") - abs( - mel_dl[:-1])).sum() == 0, ( - abs(mel.T).astype("float32") - abs(mel_dl[:-1])).sum() - - # check mel-spec correctness - mel_spec = mel_input[-1].cpu().numpy() - wav = self.ap.inv_mel_spectrogram(mel_spec.T) - self.ap.save_wav(wav, - OUTPATH + '/mel_inv_dataloader_cache.wav') - shutil.copy(item_idx[-1], OUTPATH + '/mel_target_dataloader_cache.wav') - - # check linear-spec - linear_spec = linear_input[-1].cpu().numpy() - wav = self.ap.inv_spectrogram(linear_spec.T) - self.ap.save_wav(wav, OUTPATH + '/linear_inv_dataloader_cache.wav') - shutil.copy(item_idx[-1], OUTPATH + '/linear_target_dataloader_cache.wav') - - # check the last time step to be zero padded - assert mel_input[0, -1].sum() == 0 - assert mel_input[0, -2].sum() != 0 - assert stop_target[0, -1] == 1 - assert stop_target[0, -2] == 0 - assert stop_target.sum() == 1 - assert len(mel_lengths.shape) == 1 - assert mel_lengths[0] == mel_input[0].shape[0] - - # Test for batch size 2 - dataloader, dataset = self._create_dataloader(2, 1, 0) - for i, data in enumerate(dataloader): - if i == self.max_loader_iter: - break - text_input = data[0] - text_lengths = data[1] - linear_input = data[2] - mel_input = data[3] - mel_lengths = data[4] - stop_target = data[5] - item_idx = data[6] - - if mel_lengths[0] > mel_lengths[1]: - idx = 0 - else: - idx = 1 - - # check the first item in the batch - assert mel_input[idx, -1].sum() == 0 - assert mel_input[idx, -2].sum() != 0, mel_input - assert stop_target[idx, -1] == 1 - assert stop_target[idx, -2] == 0 - assert stop_target[idx].sum() == 1 - assert len(mel_lengths.shape) == 1 - assert mel_lengths[idx] == 
mel_input[idx].shape[0] - - # check the second itme in the batch - assert mel_input[1 - idx, -1].sum() == 0 - assert stop_target[1 - idx, -1] == 1 - assert len(mel_lengths.shape) == 1 - - # check batch conditions - assert (mel_input * stop_target.unsqueeze(2)).sum() == 0 - - -# class TestTTSDatasetMemory(unittest.TestCase): -# def __init__(self, *args, **kwargs): -# super(TestTTSDatasetMemory, self).__init__(*args, **kwargs) -# self.max_loader_iter = 4 -# self.c = load_config(os.path.join(c.data_path_cache, 'config.json')) -# self.ap = AudioProcessor(**c.audio) - -# def test_loader(self): -# if ok_ljspeech: -# dataset = TTSDatasetMemory.MyDataset( -# c.data_path_cache, -# 'tts_metadata.csv', -# c.r, -# c.text_cleaner, -# preprocessor=tts_cache, -# ap=self.ap, -# min_seq_len=c.min_seq_len) - -# dataloader = DataLoader( -# dataset, -# batch_size=2, -# shuffle=True, -# collate_fn=dataset.collate_fn, -# drop_last=True, -# num_workers=c.num_loader_workers) - -# for i, data in enumerate(dataloader): -# if i == self.max_loader_iter: -# break -# text_input = data[0] -# text_lengths = data[1] -# linear_input = data[2] -# mel_input = data[3] -# mel_lengths = data[4] -# stop_target = data[5] -# item_idx = data[6] - -# neg_values = text_input[text_input < 0] -# check_count = len(neg_values) -# assert check_count == 0, \ -# " !! Negative values in text_input: {}".format(check_count) -# # check mel-spec shape -# assert mel_input.shape[0] == c.batch_size -# assert mel_input.shape[2] == c.audio['num_mels'] -# assert mel_input.max() <= self.ap.max_norm -# # check data range -# if self.ap.symmetric_norm: -# assert mel_input.max() <= self.ap.max_norm -# assert mel_input.min() >= -self.ap.max_norm -# assert mel_input.min() < 0 -# else: -# assert mel_input.max() <= self.ap.max_norm -# assert mel_input.min() >= 0 - -# def test_batch_group_shuffle(self): -# if ok_ljspeech: -# dataset = TTSDatasetMemory.MyDataset( -# c.data_path_cache, -# 'tts_metadata.csv', -# c.r, -# c.text_cleaner, -# preprocessor=ljspeech, -# ap=self.ap, -# batch_group_size=16, -# min_seq_len=c.min_seq_len) - -# dataloader = DataLoader( -# dataset, -# batch_size=2, -# shuffle=True, -# collate_fn=dataset.collate_fn, -# drop_last=True, -# num_workers=c.num_loader_workers) - -# frames = dataset.items -# for i, data in enumerate(dataloader): -# if i == self.max_loader_iter: -# break -# text_input = data[0] -# text_lengths = data[1] -# linear_input = data[2] -# mel_input = data[3] -# mel_lengths = data[4] -# stop_target = data[5] -# item_idx = data[6] - -# neg_values = text_input[text_input < 0] -# check_count = len(neg_values) -# assert check_count == 0, \ -# " !! 
Negative values in text_input: {}".format(check_count) -# assert mel_input.shape[0] == c.batch_size -# assert mel_input.shape[2] == c.audio['num_mels'] -# dataloader.dataset.sort_items() -# assert frames[0] != dataloader.dataset.items[0] - -# def test_padding_and_spec(self): -# if ok_ljspeech: -# dataset = TTSDatasetMemory.MyDataset( -# c.data_path_cache, -# 'tts_meta_data.csv', -# 1, -# c.text_cleaner, -# preprocessor=ljspeech, -# ap=self.ap, -# min_seq_len=c.min_seq_len) - -# # Test for batch size 1 -# dataloader = DataLoader( -# dataset, -# batch_size=1, -# shuffle=False, -# collate_fn=dataset.collate_fn, -# drop_last=True, -# num_workers=c.num_loader_workers) - -# for i, data in enumerate(dataloader): -# if i == self.max_loader_iter: -# break -# text_input = data[0] -# text_lengths = data[1] -# linear_input = data[2] -# mel_input = data[3] -# mel_lengths = data[4] -# stop_target = data[5] -# item_idx = data[6] - -# # check mel_spec consistency -# if item_idx[0].split('.')[-1] == 'npy': -# wav = np.load(item_idx[0]) -# else: -# wav = self.ap.load_wav(item_idx[0]) -# mel = self.ap.melspectrogram(wav) -# mel_dl = mel_input[0].cpu().numpy() -# assert ( -# abs(mel.T).astype("float32") - abs(mel_dl[:-1])).sum() == 0 - -# # check mel-spec correctness -# mel_spec = mel_input[0].cpu().numpy() -# wav = self.ap.inv_mel_spectrogram(mel_spec.T) -# self.ap.save_wav(wav, OUTPATH + '/mel_inv_dataloader_memo.wav') -# shutil.copy(item_idx[0], OUTPATH + '/mel_target_dataloader_memo.wav') - -# # check the last time step to be zero padded -# assert mel_input[0, -1].sum() == 0 -# assert mel_input[0, -2].sum() != 0 -# assert stop_target[0, -1] == 1 -# assert stop_target[0, -2] == 0 -# assert stop_target.sum() == 1 -# assert len(mel_lengths.shape) == 1 -# assert mel_lengths[0] == mel_input[0].shape[0] - -# # Test for batch size 2 -# dataloader = DataLoader( -# dataset, -# batch_size=2, -# shuffle=False, -# collate_fn=dataset.collate_fn, -# drop_last=False, -# num_workers=c.num_loader_workers) - -# for i, data in enumerate(dataloader): -# if i == self.max_loader_iter: -# break -# text_input = data[0] -# text_lengths = data[1] -# linear_input = data[2] -# mel_input = data[3] -# mel_lengths = data[4] -# stop_target = data[5] -# item_idx = data[6] - -# if mel_lengths[0] > mel_lengths[1]: -# idx = 0 -# else: -# idx = 1 - -# # check the first item in the batch -# assert mel_input[idx, -1].sum() == 0 -# assert mel_input[idx, -2].sum() != 0, mel_input -# assert stop_target[idx, -1] == 1 -# assert stop_target[idx, -2] == 0 -# assert stop_target[idx].sum() == 1 -# assert len(mel_lengths.shape) == 1 -# assert mel_lengths[idx] == mel_input[idx].shape[0] - -# # check the second itme in the batch -# assert mel_input[1 - idx, -1].sum() == 0 -# assert stop_target[1 - idx, -1] == 1 -# assert len(mel_lengths.shape) == 1 - -# # check batch conditions -# assert (mel_input * stop_target.unsqueeze(2)).sum() == 0 + assert (mel_input * stop_target.unsqueeze(2)).sum() == 0 \ No newline at end of file diff --git a/tests/tacotron_tests.py b/tests/tacotron_tests.py index bcb702b2..866e1aa4 100644 --- a/tests/tacotron_tests.py +++ b/tests/tacotron_tests.py @@ -35,8 +35,8 @@ class TacotronTrainTest(unittest.TestCase): criterion = L1LossMasked().to(device) criterion_st = nn.BCELoss().to(device) - model = Tacotron(c.embedding_size, c.audio['num_freq'], c.audio['num_mels'], - c.r).to(device) + model = Tacotron(32, c.embedding_size, c.audio['num_freq'], c.audio['num_mels'], + c.r, c.memory_size).to(device) model.train() model_ref = 
copy.deepcopy(model) count = 0 diff --git a/tests/test_config.json b/tests/test_config.json index 7283664c..b4436572 100644 --- a/tests/test_config.json +++ b/tests/test_config.json @@ -32,6 +32,7 @@ "mk": 1.0, "priority_freq": false, "num_loader_workers": 4, + "memory_size": 5, "save_step": 200, "data_path": "/home/erogol/Data/LJSpeech-1.1/",
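
The datasets/TTSDataset.py hunk above guards os.makedirs with use_phonemes, so the phoneme cache directory is only created when phoneme preprocessing is actually enabled; that is what lets the updated loader test construct MyDataset with use_phonemes=False and no cache path prepared. A minimal, self-contained sketch of that behavior (the helper name and the tempfile demo are illustrative only, not part of the patch):

    import os
    import tempfile

    def maybe_make_phoneme_cache(use_phonemes, phoneme_cache_path):
        # Mirrors the condition added to MyDataset.__init__: the cache
        # directory is created only when phonemes are actually used.
        if use_phonemes and not os.path.isdir(phoneme_cache_path):
            os.makedirs(phoneme_cache_path)

    with tempfile.TemporaryDirectory() as tmp:
        cache = os.path.join(tmp, "phoneme_cache")
        maybe_make_phoneme_cache(False, cache)  # use_phonemes off: nothing is created
        assert not os.path.isdir(cache)
        maybe_make_phoneme_cache(True, cache)   # use_phonemes on: directory appears on demand
        assert os.path.isdir(cache)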
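
The remaining test hunks track the new memory-related arguments: Decoder now receives memory_size and attn_windowing, the Tacotron constructor takes an extra leading integer plus c.memory_size, and tests/test_config.json gains "memory_size": 5 so the test config can supply it. Condensed from the hunks above; the import paths and the config fixture `c` are assumed from the surrounding test files rather than shown in this patch:

    # assumed imports, not part of this patch
    from layers.tacotron import Decoder
    from models.tacotron import Tacotron

    # layers test: Decoder picks up the two new keyword arguments
    layer = Decoder(in_features=256, memory_dim=80, r=2,
                    memory_size=4, attn_windowing=False)

    # tacotron test: the model is built with the new leading integer (32)
    # and the memory_size value read from tests/test_config.json
    model = Tacotron(32, c.embedding_size, c.audio['num_freq'],
                     c.audio['num_mels'], c.r, c.memory_size)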