From 73c3ab546de237f58e8d07cd4580b36bb71bb526 Mon Sep 17 00:00:00 2001 From: Eren Date: Mon, 13 Aug 2018 15:02:30 +0200 Subject: [PATCH] Testing update --- tests/layers_tests.py | 11 +++++++++-- tests/loader_tests.py | 2 ++ tests/test_config.json | 4 ++-- train.py | 2 +- 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/tests/layers_tests.py b/tests/layers_tests.py index 28cb5cf2..2458e00b 100644 --- a/tests/layers_tests.py +++ b/tests/layers_tests.py @@ -19,14 +19,21 @@ class PrenetTests(unittest.TestCase): class CBHGTests(unittest.TestCase): def test_in_out(self): - layer = CBHG(128, K=6, projections=[128, 128], num_highways=2) + layer = self.cbhg = CBHG( + 128, + K=8, + conv_bank_features=80, + conv_projections=[160, 128], + highway_features=80, + gru_features=80, + num_highways=4) dummy_input = T.rand(4, 8, 128) print(layer) output = layer(dummy_input) assert output.shape[0] == 4 assert output.shape[1] == 8 - assert output.shape[2] == 256 + assert output.shape[2] == 160 class DecoderTests(unittest.TestCase): diff --git a/tests/loader_tests.py b/tests/loader_tests.py index 695861bb..f5e6b9d5 100644 --- a/tests/loader_tests.py +++ b/tests/loader_tests.py @@ -9,6 +9,8 @@ from TTS.datasets import LJSpeech, Kusal file_path = os.path.dirname(os.path.realpath(__file__)) c = load_config(os.path.join(file_path, 'test_config.json')) +ok_kusal = os.path.exists(c.data_path_Kusal) +ok_ljspeech = os.path.exists(c.data_path_LJSpeech) class TestLJSpeechDataset(unittest.TestCase): diff --git a/tests/test_config.json b/tests/test_config.json index 41d91fab..af0d070d 100644 --- a/tests/test_config.json +++ b/tests/test_config.json @@ -29,8 +29,8 @@ "num_loader_workers": 4, "save_step": 200, - "data_path_LJSpeech": "C:/Users/erogol/Data/LJSpeech-1.1", - "data_path_Kusal": "C:/Users/erogol/Data/Kusal", + "data_path_LJSpeech": "/home/erogol/Data/LJSpeech-1.1", + "data_path_Kusal": "/home/erogol/Data/Kusal", "output_path": "result", "min_seq_len": 0, "log_dir": "/home/erogol/projects/TTS/logs/" diff --git a/train.py b/train.py index 7078acb0..098ac94c 100644 --- a/train.py +++ b/train.py @@ -24,7 +24,7 @@ from utils.audio import AudioProcessor torch.manual_seed(1) -torch.set_num_threads(4) +# torch.set_num_threads(4) use_cuda = torch.cuda.is_available()