tests update

This commit is contained in:
Eren Golge 2018-05-10 16:00:21 -07:00
parent 14c9e9cde9
commit b087c0b5ec
3 changed files with 128 additions and 127 deletions

View File

@ -5,7 +5,7 @@ import numpy as np
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from TTS.utils.generic_utils import load_config from TTS.utils.generic_utils import load_config
from TTS.datasets.LJSpeech import LJSpeechDataset from TTS.datasets.LJSpeech import LJSpeechDataset
from TTS.datasets.TWEB import TWEBDataset # from TTS.datasets.TWEB import TWEBDataset
file_path = os.path.dirname(os.path.realpath(__file__)) file_path = os.path.dirname(os.path.realpath(__file__))
@ -19,8 +19,8 @@ class TestLJSpeechDataset(unittest.TestCase):
self.max_loader_iter = 4 self.max_loader_iter = 4
def test_loader(self): def test_loader(self):
dataset = LJSpeechDataset(os.path.join(c.data_path_LJSpeech, 'metadata.csv'), dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
os.path.join(c.data_path_LJSpeech, 'wavs'), os.path.join(c.data_path, 'wavs'),
c.r, c.r,
c.sample_rate, c.sample_rate,
c.text_cleaner, c.text_cleaner,
@ -59,8 +59,8 @@ class TestLJSpeechDataset(unittest.TestCase):
assert mel_input.shape[2] == c.num_mels assert mel_input.shape[2] == c.num_mels
def test_padding(self): def test_padding(self):
dataset = LJSpeechDataset(os.path.join(c.data_path_LJSpeech, 'metadata.csv'), dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
os.path.join(c.data_path_LJSpeech, 'wavs'), os.path.join(c.data_path, 'wavs'),
1, 1,
c.sample_rate, c.sample_rate,
c.text_cleaner, c.text_cleaner,
@ -144,134 +144,134 @@ class TestLJSpeechDataset(unittest.TestCase):
assert (linear_input * stop_target.unsqueeze(2)).sum() == 0 assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
class TestTWEBDataset(unittest.TestCase): # class TestTWEBDataset(unittest.TestCase):
def __init__(self, *args, **kwargs): # def __init__(self, *args, **kwargs):
super(TestTWEBDataset, self).__init__(*args, **kwargs) # super(TestTWEBDataset, self).__init__(*args, **kwargs)
self.max_loader_iter = 4 # self.max_loader_iter = 4
def test_loader(self): # def test_loader(self):
dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'), # dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
os.path.join(c.data_path_TWEB, 'wavs'), # os.path.join(c.data_path_TWEB, 'wavs'),
c.r, # c.r,
c.sample_rate, # c.sample_rate,
c.text_cleaner, # c.text_cleaner,
c.num_mels, # c.num_mels,
c.min_level_db, # c.min_level_db,
c.frame_shift_ms, # c.frame_shift_ms,
c.frame_length_ms, # c.frame_length_ms,
c.preemphasis, # c.preemphasis,
c.ref_level_db, # c.ref_level_db,
c.num_freq, # c.num_freq,
c.power # c.power
) # )
dataloader = DataLoader(dataset, batch_size=2, # dataloader = DataLoader(dataset, batch_size=2,
shuffle=True, collate_fn=dataset.collate_fn, # shuffle=True, collate_fn=dataset.collate_fn,
drop_last=True, num_workers=c.num_loader_workers) # drop_last=True, num_workers=c.num_loader_workers)
for i, data in enumerate(dataloader): # for i, data in enumerate(dataloader):
if i == self.max_loader_iter: # if i == self.max_loader_iter:
break # break
text_input = data[0] # text_input = data[0]
text_lengths = data[1] # text_lengths = data[1]
linear_input = data[2] # linear_input = data[2]
mel_input = data[3] # mel_input = data[3]
mel_lengths = data[4] # mel_lengths = data[4]
stop_target = data[5] # stop_target = data[5]
item_idx = data[6] # item_idx = data[6]
neg_values = text_input[text_input < 0] # neg_values = text_input[text_input < 0]
check_count = len(neg_values) # check_count = len(neg_values)
assert check_count == 0, \ # assert check_count == 0, \
" !! Negative values in text_input: {}".format(check_count) # " !! Negative values in text_input: {}".format(check_count)
# TODO: more assertion here # # TODO: more assertion here
assert linear_input.shape[0] == c.batch_size # assert linear_input.shape[0] == c.batch_size
assert mel_input.shape[0] == c.batch_size # assert mel_input.shape[0] == c.batch_size
assert mel_input.shape[2] == c.num_mels # assert mel_input.shape[2] == c.num_mels
def test_padding(self): # def test_padding(self):
dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'), # dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
os.path.join(c.data_path_TWEB, 'wavs'), # os.path.join(c.data_path_TWEB, 'wavs'),
1, # 1,
c.sample_rate, # c.sample_rate,
c.text_cleaner, # c.text_cleaner,
c.num_mels, # c.num_mels,
c.min_level_db, # c.min_level_db,
c.frame_shift_ms, # c.frame_shift_ms,
c.frame_length_ms, # c.frame_length_ms,
c.preemphasis, # c.preemphasis,
c.ref_level_db, # c.ref_level_db,
c.num_freq, # c.num_freq,
c.power # c.power
) # )
# Test for batch size 1 # # Test for batch size 1
dataloader = DataLoader(dataset, batch_size=1, # dataloader = DataLoader(dataset, batch_size=1,
shuffle=False, collate_fn=dataset.collate_fn, # shuffle=False, collate_fn=dataset.collate_fn,
drop_last=False, num_workers=c.num_loader_workers) # drop_last=False, num_workers=c.num_loader_workers)
for i, data in enumerate(dataloader): # for i, data in enumerate(dataloader):
if i == self.max_loader_iter: # if i == self.max_loader_iter:
break # break
text_input = data[0] # text_input = data[0]
text_lengths = data[1] # text_lengths = data[1]
linear_input = data[2] # linear_input = data[2]
mel_input = data[3] # mel_input = data[3]
mel_lengths = data[4] # mel_lengths = data[4]
stop_target = data[5] # stop_target = data[5]
item_idx = data[6] # item_idx = data[6]
# check the last time step to be zero padded # # check the last time step to be zero padded
assert mel_input[0, -1].sum() == 0 # assert mel_input[0, -1].sum() == 0
assert mel_input[0, -2].sum() != 0, "{} -- {}".format(item_idx, i) # assert mel_input[0, -2].sum() != 0, "{} -- {}".format(item_idx, i)
assert linear_input[0, -1].sum() == 0 # assert linear_input[0, -1].sum() == 0
assert linear_input[0, -2].sum() != 0 # assert linear_input[0, -2].sum() != 0
assert stop_target[0, -1] == 1 # assert stop_target[0, -1] == 1
assert stop_target[0, -2] == 0 # assert stop_target[0, -2] == 0
assert stop_target.sum() == 1 # assert stop_target.sum() == 1
assert len(mel_lengths.shape) == 1 # assert len(mel_lengths.shape) == 1
assert mel_lengths[0] == mel_input[0].shape[0] # assert mel_lengths[0] == mel_input[0].shape[0]
# Test for batch size 2 # # Test for batch size 2
dataloader = DataLoader(dataset, batch_size=2, # dataloader = DataLoader(dataset, batch_size=2,
shuffle=False, collate_fn=dataset.collate_fn, # shuffle=False, collate_fn=dataset.collate_fn,
drop_last=False, num_workers=c.num_loader_workers) # drop_last=False, num_workers=c.num_loader_workers)
for i, data in enumerate(dataloader): # for i, data in enumerate(dataloader):
if i == self.max_loader_iter: # if i == self.max_loader_iter:
break # break
text_input = data[0] # text_input = data[0]
text_lengths = data[1] # text_lengths = data[1]
linear_input = data[2] # linear_input = data[2]
mel_input = data[3] # mel_input = data[3]
mel_lengths = data[4] # mel_lengths = data[4]
stop_target = data[5] # stop_target = data[5]
item_idx = data[6] # item_idx = data[6]
if mel_lengths[0] > mel_lengths[1]: # if mel_lengths[0] > mel_lengths[1]:
idx = 0 # idx = 0
else: # else:
idx = 1 # idx = 1
# check the first item in the batch # # check the first item in the batch
assert mel_input[idx, -1].sum() == 0 # assert mel_input[idx, -1].sum() == 0
assert mel_input[idx, -2].sum() != 0, mel_input # assert mel_input[idx, -2].sum() != 0, mel_input
assert linear_input[idx, -1].sum() == 0 # assert linear_input[idx, -1].sum() == 0
assert linear_input[idx, -2].sum() != 0 # assert linear_input[idx, -2].sum() != 0
assert stop_target[idx, -1] == 1 # assert stop_target[idx, -1] == 1
assert stop_target[idx, -2] == 0 # assert stop_target[idx, -2] == 0
assert stop_target[idx].sum() == 1 # assert stop_target[idx].sum() == 1
assert len(mel_lengths.shape) == 1 # assert len(mel_lengths.shape) == 1
assert mel_lengths[idx] == mel_input[idx].shape[0] # assert mel_lengths[idx] == mel_input[idx].shape[0]
# check the second itme in the batch # # check the second itme in the batch
assert mel_input[1-idx, -1].sum() == 0 # assert mel_input[1-idx, -1].sum() == 0
assert linear_input[1-idx, -1].sum() == 0 # assert linear_input[1-idx, -1].sum() == 0
assert stop_target[1-idx, -1] == 1 # assert stop_target[1-idx, -1] == 1
assert len(mel_lengths.shape) == 1 # assert len(mel_lengths.shape) == 1
# check batch conditions # # check batch conditions
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0 # assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
assert (linear_input * stop_target.unsqueeze(2)).sum() == 0 # assert (linear_input * stop_target.unsqueeze(2)).sum() == 0

View File

@ -42,20 +42,21 @@ class TacotronTrainTest(unittest.TestCase):
count += 1 count += 1
optimizer = optim.Adam(model.parameters(), lr=c.lr) optimizer = optim.Adam(model.parameters(), lr=c.lr)
for i in range(5): for i in range(5):
mel_out, linear_out, align, stop_tokens = model.forward(input, mel_spec) mel_out, linear_out, align = model.forward(input, mel_spec)
assert stop_tokens.data.max() <= 1.0 # mel_out, linear_out, align, stop_tokens = model.forward(input, mel_spec)
assert stop_tokens.data.min() >= 0.0 # assert stop_tokens.data.max() <= 1.0
# assert stop_tokens.data.min() >= 0.0
optimizer.zero_grad() optimizer.zero_grad()
loss = criterion(mel_out, mel_spec, mel_lengths) loss = criterion(mel_out, mel_spec, mel_lengths)
stop_loss = criterion_st(stop_tokens, stop_targets) # stop_loss = criterion_st(stop_tokens, stop_targets)
loss = loss + criterion(linear_out, linear_spec, mel_lengths) + stop_loss loss = loss + criterion(linear_out, linear_spec, mel_lengths)
loss.backward() loss.backward()
optimizer.step() optimizer.step()
# check parameter changes # check parameter changes
count = 0 count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()): for param, param_ref in zip(model.parameters(), model_ref.parameters()):
# ignore pre-higway layer since it works conditional # ignore pre-higway layer since it works conditional
if count not in [141, 59]: if count not in [139, 59]:
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(count, param.shape, param, param_ref) assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(count, param.shape, param, param_ref)
count += 1 count += 1