diff --git a/tests/layers_tests.py b/tests/layers_tests.py index b0849a3a..9b5c3f73 100644 --- a/tests/layers_tests.py +++ b/tests/layers_tests.py @@ -2,7 +2,8 @@ import unittest import torch as T from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder -from TTS.layers.losses import L1LossMasked, _sequence_mask +from TTS.layers.losses import L1LossMasked +from TTS.utils.generic_utils import sequence_mask class PrenetTests(unittest.TestCase): @@ -79,7 +80,7 @@ class L1LossMaskedTests(unittest.TestCase): dummy_input = T.ones(4, 8, 128).float() dummy_target = T.zeros(4, 8, 128).float() dummy_length = (T.arange(5, 9)).long() - mask = ((_sequence_mask(dummy_length).float() - 1.0) + mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) output = layer(dummy_input + mask, dummy_target, dummy_length) assert output.item() == 1.0, "1.0 vs {}".format(output.data[0]) diff --git a/tests/loader_tests.py b/tests/loader_tests.py index f80abbfd..927126b4 100644 --- a/tests/loader_tests.py +++ b/tests/loader_tests.py @@ -4,7 +4,8 @@ import numpy as np from torch.utils.data import DataLoader from TTS.utils.generic_utils import load_config -from TTS.datasets.LJSpeech import LJSpeechDataset +from TTS.utils.audio import AudioProcessor +from TTS.datasets import LJSpeech, Kusal file_path = os.path.dirname(os.path.realpath(__file__)) c = load_config(os.path.join(file_path, 'test_config.json')) @@ -15,21 +16,25 @@ class TestLJSpeechDataset(unittest.TestCase): def __init__(self, *args, **kwargs): super(TestLJSpeechDataset, self).__init__(*args, **kwargs) self.max_loader_iter = 4 + self.ap = AudioProcessor(sample_rate=c.sample_rate, + num_mels=c.num_mels, + min_level_db=c.min_level_db, + frame_shift_ms=c.frame_shift_ms, + frame_length_ms=c.frame_length_ms, + ref_level_db=c.ref_level_db, + num_freq=c.num_freq, + power=c.power, + preemphasis=c.preemphasis, + min_mel_freq=c.min_mel_freq, + max_mel_freq=c.max_mel_freq) def test_loader(self): - dataset = LJSpeechDataset(os.path.join(c.data_path_LJSpeech, 'metadata.csv'), - os.path.join(c.data_path_LJSpeech, 'wavs'), + dataset = LJSpeech.MyDataset(os.path.join(c.data_path_LJSpeech), + os.path.join(c.data_path_LJSpeech, 'metadata.csv'), c.r, - c.sample_rate, c.text_cleaner, - c.num_mels, - c.min_level_db, - c.frame_shift_ms, - c.frame_length_ms, - c.preemphasis, - c.ref_level_db, - c.num_freq, - c.power + ap = self.ap, + min_seq_len=c.min_seq_len ) dataloader = DataLoader(dataset, batch_size=2, @@ -57,19 +62,12 @@ class TestLJSpeechDataset(unittest.TestCase): assert mel_input.shape[2] == c.num_mels def test_padding(self): - dataset = LJSpeechDataset(os.path.join(c.data_path_LJSpeech, 'metadata.csv'), - os.path.join(c.data_path_LJSpeech, 'wavs'), + dataset = LJSpeech.MyDataset(os.path.join(c.data_path_LJSpeech), + os.path.join(c.data_path_LJSpeech, 'metadata.csv'), 1, - c.sample_rate, c.text_cleaner, - c.num_mels, - c.min_level_db, - c.frame_shift_ms, - c.frame_length_ms, - c.preemphasis, - c.ref_level_db, - c.num_freq, - c.power + ap = self.ap, + min_seq_len=c.min_seq_len ) # Test for batch size 1 @@ -141,6 +139,135 @@ class TestLJSpeechDataset(unittest.TestCase): assert (mel_input * stop_target.unsqueeze(2)).sum() == 0 assert (linear_input * stop_target.unsqueeze(2)).sum() == 0 + +class TestKusalDataset(unittest.TestCase): + + def __init__(self, *args, **kwargs): + super(TestKusalDataset, self).__init__(*args, **kwargs) + self.max_loader_iter = 4 + self.ap = AudioProcessor(sample_rate=c.sample_rate, + num_mels=c.num_mels, + min_level_db=c.min_level_db, + frame_shift_ms=c.frame_shift_ms, + frame_length_ms=c.frame_length_ms, + ref_level_db=c.ref_level_db, + num_freq=c.num_freq, + power=c.power, + preemphasis=c.preemphasis, + min_mel_freq=c.min_mel_freq, + max_mel_freq=c.max_mel_freq) + + def test_loader(self): + dataset = Kusal.MyDataset(os.path.join(c.data_path_Kusal), + os.path.join(c.data_path_Kusal, 'prompts.txt'), + c.r, + c.text_cleaner, + ap = self.ap, + min_seq_len=c.min_seq_len + ) + + dataloader = DataLoader(dataset, batch_size=2, + shuffle=True, collate_fn=dataset.collate_fn, + drop_last=True, num_workers=c.num_loader_workers) + + for i, data in enumerate(dataloader): + if i == self.max_loader_iter: + break + text_input = data[0] + text_lengths = data[1] + linear_input = data[2] + mel_input = data[3] + mel_lengths = data[4] + stop_target = data[5] + item_idx = data[6] + + neg_values = text_input[text_input < 0] + check_count = len(neg_values) + assert check_count == 0, \ + " !! Negative values in text_input: {}".format(check_count) + # TODO: more assertion here + assert linear_input.shape[0] == c.batch_size + assert mel_input.shape[0] == c.batch_size + assert mel_input.shape[2] == c.num_mels + + def test_padding(self): + dataset = Kusal.MyDataset(os.path.join(c.data_path_Kusal), + os.path.join(c.data_path_Kusal, 'prompts.txt'), + 1, + c.text_cleaner, + ap = self.ap, + min_seq_len=c.min_seq_len + ) + + # Test for batch size 1 + dataloader = DataLoader(dataset, batch_size=1, + shuffle=False, collate_fn=dataset.collate_fn, + drop_last=True, num_workers=c.num_loader_workers) + + for i, data in enumerate(dataloader): + if i == self.max_loader_iter: + break + text_input = data[0] + text_lengths = data[1] + linear_input = data[2] + mel_input = data[3] + mel_lengths = data[4] + stop_target = data[5] + item_idx = data[6] + + # check the last time step to be zero padded + assert mel_input[0, -1].sum() == 0 + # assert mel_input[0, -2].sum() != 0 + assert linear_input[0, -1].sum() == 0 + # assert linear_input[0, -2].sum() != 0 + assert stop_target[0, -1] == 1 + assert stop_target[0, -2] == 0 + assert stop_target.sum() == 1 + assert len(mel_lengths.shape) == 1 + assert mel_lengths[0] == mel_input[0].shape[0] + + # Test for batch size 2 + dataloader = DataLoader(dataset, batch_size=2, + shuffle=False, collate_fn=dataset.collate_fn, + drop_last=False, num_workers=c.num_loader_workers) + + for i, data in enumerate(dataloader): + if i == self.max_loader_iter: + break + text_input = data[0] + text_lengths = data[1] + linear_input = data[2] + mel_input = data[3] + mel_lengths = data[4] + stop_target = data[5] + item_idx = data[6] + + if mel_lengths[0] > mel_lengths[1]: + idx = 0 + else: + idx = 1 + + # check the first item in the batch + assert mel_input[idx, -1].sum() == 0 + assert mel_input[idx, -2].sum() != 0, mel_input + assert linear_input[idx, -1].sum() == 0 + assert linear_input[idx, -2].sum() != 0 + assert stop_target[idx, -1] == 1 + assert stop_target[idx, -2] == 0 + assert stop_target[idx].sum() == 1 + assert len(mel_lengths.shape) == 1 + assert mel_lengths[idx] == mel_input[idx].shape[0] + + # check the second itme in the batch + assert mel_input[1-idx, -1].sum() == 0 + assert linear_input[1-idx, -1].sum() == 0 + assert stop_target[1-idx, -1] == 1 + assert len(mel_lengths.shape) == 1 + + # check batch conditions + assert (mel_input * stop_target.unsqueeze(2)).sum() == 0 + assert (linear_input * stop_target.unsqueeze(2)).sum() == 0 + # class TestTWEBDataset(unittest.TestCase): diff --git a/tests/test_config.json b/tests/test_config.json index ec67521b..41d91fab 100644 --- a/tests/test_config.json +++ b/tests/test_config.json @@ -9,6 +9,8 @@ "ref_level_db": 20, "hidden_size": 128, "embedding_size": 256, + "min_mel_freq": null, + "max_mel_freq": null, "text_cleaner": "english_cleaners", "epochs": 2000, @@ -27,8 +29,9 @@ "num_loader_workers": 4, "save_step": 200, - "data_path_LJSpeech": "/data/shared/KeithIto/LJSpeech-1.0", - "data_path_TWEB": "/data/shared/BibleSpeech", + "data_path_LJSpeech": "C:/Users/erogol/Data/LJSpeech-1.1", + "data_path_Kusal": "C:/Users/erogol/Data/Kusal", "output_path": "result", + "min_seq_len": 0, "log_dir": "/home/erogol/projects/TTS/logs/" }