From 8d6e31346c8d48f278dbfc9a67ffa38ff6b92b61 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Thu, 10 May 2018 16:00:21 -0700
Subject: [PATCH] tests update

---
 tests/layers_tests.py   |   2 +-
 tests/loader_tests.py   | 240 ++++++++++++++++++++--------------------
 tests/tacotron_tests.py |  13 ++-
 3 files changed, 128 insertions(+), 127 deletions(-)

diff --git a/tests/layers_tests.py b/tests/layers_tests.py
index 2303d8b9..cc2ac048 100644
--- a/tests/layers_tests.py
+++ b/tests/layers_tests.py
@@ -73,7 +73,7 @@ class L1LossMaskedTests(unittest.TestCase):
         dummy_length = (T.ones(4) * 8).long()
         output = layer(dummy_input, dummy_target, dummy_length)
         assert output.item() == 1.0, "1.0 vs {}".format(output.data[0])
-    
+
         dummy_input = T.ones(4, 8, 128).float()
         dummy_target = T.zeros(4, 8, 128).float()
         dummy_length = (T.arange(5, 9)).long()
diff --git a/tests/loader_tests.py b/tests/loader_tests.py
index 76d82557..2aa30a4c 100644
--- a/tests/loader_tests.py
+++ b/tests/loader_tests.py
@@ -5,7 +5,7 @@ import numpy as np
 from torch.utils.data import DataLoader
 from TTS.utils.generic_utils import load_config
 from TTS.datasets.LJSpeech import LJSpeechDataset
-from TTS.datasets.TWEB import TWEBDataset
+# from TTS.datasets.TWEB import TWEBDataset
 
 file_path = os.path.dirname(os.path.realpath(__file__))
 
@@ -19,8 +19,8 @@ class TestLJSpeechDataset(unittest.TestCase):
         self.max_loader_iter = 4
 
     def test_loader(self):
-        dataset = LJSpeechDataset(os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
-                                  os.path.join(c.data_path_LJSpeech, 'wavs'),
+        dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
+                                  os.path.join(c.data_path, 'wavs'),
                                   c.r,
                                   c.sample_rate,
                                   c.text_cleaner,
@@ -59,8 +59,8 @@ class TestLJSpeechDataset(unittest.TestCase):
             assert mel_input.shape[2] == c.num_mels
 
     def test_padding(self):
-        dataset = LJSpeechDataset(os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
-                                  os.path.join(c.data_path_LJSpeech, 'wavs'),
+        dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
+                                  os.path.join(c.data_path, 'wavs'),
                                   1,
                                   c.sample_rate,
                                   c.text_cleaner,
@@ -144,134 +144,134 @@ class TestLJSpeechDataset(unittest.TestCase):
         assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
 
 
-class TestTWEBDataset(unittest.TestCase):
+# class TestTWEBDataset(unittest.TestCase):
 
-    def __init__(self, *args, **kwargs):
-        super(TestTWEBDataset, self).__init__(*args, **kwargs)
-        self.max_loader_iter = 4
+#     def __init__(self, *args, **kwargs):
+#         super(TestTWEBDataset, self).__init__(*args, **kwargs)
+#         self.max_loader_iter = 4
 
-    def test_loader(self):
-        dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
-                              os.path.join(c.data_path_TWEB, 'wavs'),
-                              c.r,
-                              c.sample_rate,
-                              c.text_cleaner,
-                              c.num_mels,
-                              c.min_level_db,
-                              c.frame_shift_ms,
-                              c.frame_length_ms,
-                              c.preemphasis,
-                              c.ref_level_db,
-                              c.num_freq,
-                              c.power
-                              )
+#     def test_loader(self):
+#         dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
+#                               os.path.join(c.data_path_TWEB, 'wavs'),
+#                               c.r,
+#                               c.sample_rate,
+#                               c.text_cleaner,
+#                               c.num_mels,
+#                               c.min_level_db,
+#                               c.frame_shift_ms,
+#                               c.frame_length_ms,
+#                               c.preemphasis,
+#                               c.ref_level_db,
+#                               c.num_freq,
+#                               c.power
+#                               )
 
-        dataloader = DataLoader(dataset, batch_size=2,
-                                shuffle=True, collate_fn=dataset.collate_fn,
-                                drop_last=True, num_workers=c.num_loader_workers)
+#         dataloader = DataLoader(dataset, batch_size=2,
+#                                 shuffle=True, collate_fn=dataset.collate_fn,
+#                                 drop_last=True, num_workers=c.num_loader_workers)
 
-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+#         for i, data in enumerate(dataloader):
+#             if i == self.max_loader_iter:
+#                 break
+#             text_input = data[0]
+#             text_lengths = data[1]
+#             linear_input = data[2]
+#             mel_input = data[3]
+#             mel_lengths = data[4]
+#             stop_target = data[5]
+#             item_idx = data[6]
 
-            neg_values = text_input[text_input < 0]
-            check_count = len(neg_values)
-            assert check_count == 0, \
-                " !! Negative values in text_input: {}".format(check_count)
-            # TODO: more assertion here
-            assert linear_input.shape[0] == c.batch_size
-            assert mel_input.shape[0] == c.batch_size
-            assert mel_input.shape[2] == c.num_mels
+#             neg_values = text_input[text_input < 0]
+#             check_count = len(neg_values)
+#             assert check_count == 0, \
+#                 " !! Negative values in text_input: {}".format(check_count)
+#             # TODO: more assertion here
+#             assert linear_input.shape[0] == c.batch_size
+#             assert mel_input.shape[0] == c.batch_size
+#             assert mel_input.shape[2] == c.num_mels
 
-    def test_padding(self):
-        dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
-                              os.path.join(c.data_path_TWEB, 'wavs'),
-                              1,
-                              c.sample_rate,
-                              c.text_cleaner,
-                              c.num_mels,
-                              c.min_level_db,
-                              c.frame_shift_ms,
-                              c.frame_length_ms,
-                              c.preemphasis,
-                              c.ref_level_db,
-                              c.num_freq,
-                              c.power
-                              )
+#     def test_padding(self):
+#         dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
+#                               os.path.join(c.data_path_TWEB, 'wavs'),
+#                               1,
+#                               c.sample_rate,
+#                               c.text_cleaner,
+#                               c.num_mels,
+#                               c.min_level_db,
+#                               c.frame_shift_ms,
+#                               c.frame_length_ms,
+#                               c.preemphasis,
+#                               c.ref_level_db,
+#                               c.num_freq,
+#                               c.power
+#                               )
 
-        # Test for batch size 1
-        dataloader = DataLoader(dataset, batch_size=1,
-                                shuffle=False, collate_fn=dataset.collate_fn,
-                                drop_last=False, num_workers=c.num_loader_workers)
+#         # Test for batch size 1
+#         dataloader = DataLoader(dataset, batch_size=1,
+#                                 shuffle=False, collate_fn=dataset.collate_fn,
+#                                 drop_last=False, num_workers=c.num_loader_workers)
 
-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
+#         for i, data in enumerate(dataloader):
+#             if i == self.max_loader_iter:
+#                 break
 
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+#             text_input = data[0]
+#             text_lengths = data[1]
+#             linear_input = data[2]
+#             mel_input = data[3]
+#             mel_lengths = data[4]
+#             stop_target = data[5]
+#             item_idx = data[6]
 
-            # check the last time step to be zero padded
-            assert mel_input[0, -1].sum() == 0
-            assert mel_input[0, -2].sum() != 0, "{} -- {}".format(item_idx, i)
-            assert linear_input[0, -1].sum() == 0
-            assert linear_input[0, -2].sum() != 0
-            assert stop_target[0, -1] == 1
-            assert stop_target[0, -2] == 0
-            assert stop_target.sum() == 1
-            assert len(mel_lengths.shape) == 1
-            assert mel_lengths[0] == mel_input[0].shape[0]
+#             # check the last time step to be zero padded
+#             assert mel_input[0, -1].sum() == 0
+#             assert mel_input[0, -2].sum() != 0, "{} -- {}".format(item_idx, i)
+#             assert linear_input[0, -1].sum() == 0
+#             assert linear_input[0, -2].sum() != 0
+#             assert stop_target[0, -1] == 1
+#             assert stop_target[0, -2] == 0
+#             assert stop_target.sum() == 1
+#             assert len(mel_lengths.shape) == 1
+#             assert mel_lengths[0] == mel_input[0].shape[0]
 
-        # Test for batch size 2
-        dataloader = DataLoader(dataset, batch_size=2,
-                                shuffle=False, collate_fn=dataset.collate_fn,
-                                drop_last=False, num_workers=c.num_loader_workers)
+#         # Test for batch size 2
+#         dataloader = DataLoader(dataset, batch_size=2,
+#                                 shuffle=False, collate_fn=dataset.collate_fn,
+#                                 drop_last=False, num_workers=c.num_loader_workers)
 
-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+#         for i, data in enumerate(dataloader):
+#             if i == self.max_loader_iter:
+#                 break
+#             text_input = data[0]
+#             text_lengths = data[1]
+#             linear_input = data[2]
+#             mel_input = data[3]
+#             mel_lengths = data[4]
+#             stop_target = data[5]
+#             item_idx = data[6]
 
-            if mel_lengths[0] > mel_lengths[1]:
-                idx = 0
-            else:
-                idx = 1
+#             if mel_lengths[0] > mel_lengths[1]:
+#                 idx = 0
+#             else:
+#                 idx = 1
 
-            # check the first item in the batch
-            assert mel_input[idx, -1].sum() == 0
-            assert mel_input[idx, -2].sum() != 0, mel_input
-            assert linear_input[idx, -1].sum() == 0
-            assert linear_input[idx, -2].sum() != 0
-            assert stop_target[idx, -1] == 1
-            assert stop_target[idx, -2] == 0
-            assert stop_target[idx].sum() == 1
-            assert len(mel_lengths.shape) == 1
-            assert mel_lengths[idx] == mel_input[idx].shape[0]
+#             # check the first item in the batch
+#             assert mel_input[idx, -1].sum() == 0
+#             assert mel_input[idx, -2].sum() != 0, mel_input
+#             assert linear_input[idx, -1].sum() == 0
+#             assert linear_input[idx, -2].sum() != 0
+#             assert stop_target[idx, -1] == 1
+#             assert stop_target[idx, -2] == 0
+#             assert stop_target[idx].sum() == 1
+#             assert len(mel_lengths.shape) == 1
+#             assert mel_lengths[idx] == mel_input[idx].shape[0]
 
-            # check the second itme in the batch
-            assert mel_input[1-idx, -1].sum() == 0
-            assert linear_input[1-idx, -1].sum() == 0
-            assert stop_target[1-idx, -1] == 1
-            assert len(mel_lengths.shape) == 1
+#             # check the second item in the batch
+#             assert mel_input[1-idx, -1].sum() == 0
+#             assert linear_input[1-idx, -1].sum() == 0
+#             assert stop_target[1-idx, -1] == 1
+#             assert len(mel_lengths.shape) == 1
 
-            # check batch conditions
-            assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
-            assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
+#             # check batch conditions
+#             assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
+#             assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
diff --git a/tests/tacotron_tests.py b/tests/tacotron_tests.py
index 90bad689..65dbf3cd 100644
--- a/tests/tacotron_tests.py
+++ b/tests/tacotron_tests.py
@@ -42,20 +42,21 @@ class TacotronTrainTest(unittest.TestCase):
             count += 1
         optimizer = optim.Adam(model.parameters(), lr=c.lr)
         for i in range(5):
-            mel_out, linear_out, align, stop_tokens = model.forward(input, mel_spec)
-            assert stop_tokens.data.max() <= 1.0
-            assert stop_tokens.data.min() >= 0.0
+            mel_out, linear_out, align = model.forward(input, mel_spec)
+            # mel_out, linear_out, align, stop_tokens = model.forward(input, mel_spec)
+            # assert stop_tokens.data.max() <= 1.0
+            # assert stop_tokens.data.min() >= 0.0
             optimizer.zero_grad()
             loss = criterion(mel_out, mel_spec, mel_lengths)
-            stop_loss = criterion_st(stop_tokens, stop_targets)
-            loss = loss + criterion(linear_out, linear_spec, mel_lengths) + stop_loss
+            # stop_loss = criterion_st(stop_tokens, stop_targets)
+            loss = loss + criterion(linear_out, linear_spec, mel_lengths)
             loss.backward()
             optimizer.step()
 
         # check parameter changes
         count = 0
         for param, param_ref in zip(model.parameters(), model_ref.parameters()):
             # ignore pre-higway layer since it works conditional
-            if count not in [141, 59]:
+            if count not in [139, 59]:
                 assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(count, param.shape, param, param_ref)
             count += 1