mirror of https://github.com/coqui-ai/TTS.git
tests update
This commit is contained in:
parent
14c9e9cde9
commit
b087c0b5ec
|
@ -5,7 +5,7 @@ import numpy as np
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from TTS.utils.generic_utils import load_config
|
from TTS.utils.generic_utils import load_config
|
||||||
from TTS.datasets.LJSpeech import LJSpeechDataset
|
from TTS.datasets.LJSpeech import LJSpeechDataset
|
||||||
from TTS.datasets.TWEB import TWEBDataset
|
# from TTS.datasets.TWEB import TWEBDataset
|
||||||
|
|
||||||
|
|
||||||
file_path = os.path.dirname(os.path.realpath(__file__))
|
file_path = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
@ -19,8 +19,8 @@ class TestLJSpeechDataset(unittest.TestCase):
|
||||||
self.max_loader_iter = 4
|
self.max_loader_iter = 4
|
||||||
|
|
||||||
def test_loader(self):
|
def test_loader(self):
|
||||||
dataset = LJSpeechDataset(os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
|
dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
|
||||||
os.path.join(c.data_path_LJSpeech, 'wavs'),
|
os.path.join(c.data_path, 'wavs'),
|
||||||
c.r,
|
c.r,
|
||||||
c.sample_rate,
|
c.sample_rate,
|
||||||
c.text_cleaner,
|
c.text_cleaner,
|
||||||
|
@ -59,8 +59,8 @@ class TestLJSpeechDataset(unittest.TestCase):
|
||||||
assert mel_input.shape[2] == c.num_mels
|
assert mel_input.shape[2] == c.num_mels
|
||||||
|
|
||||||
def test_padding(self):
|
def test_padding(self):
|
||||||
dataset = LJSpeechDataset(os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
|
dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
|
||||||
os.path.join(c.data_path_LJSpeech, 'wavs'),
|
os.path.join(c.data_path, 'wavs'),
|
||||||
1,
|
1,
|
||||||
c.sample_rate,
|
c.sample_rate,
|
||||||
c.text_cleaner,
|
c.text_cleaner,
|
||||||
|
@ -144,134 +144,134 @@ class TestLJSpeechDataset(unittest.TestCase):
|
||||||
assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
|
assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
|
||||||
|
|
||||||
|
|
||||||
class TestTWEBDataset(unittest.TestCase):
|
# class TestTWEBDataset(unittest.TestCase):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
# def __init__(self, *args, **kwargs):
|
||||||
super(TestTWEBDataset, self).__init__(*args, **kwargs)
|
# super(TestTWEBDataset, self).__init__(*args, **kwargs)
|
||||||
self.max_loader_iter = 4
|
# self.max_loader_iter = 4
|
||||||
|
|
||||||
def test_loader(self):
|
# def test_loader(self):
|
||||||
dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
|
# dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
|
||||||
os.path.join(c.data_path_TWEB, 'wavs'),
|
# os.path.join(c.data_path_TWEB, 'wavs'),
|
||||||
c.r,
|
# c.r,
|
||||||
c.sample_rate,
|
# c.sample_rate,
|
||||||
c.text_cleaner,
|
# c.text_cleaner,
|
||||||
c.num_mels,
|
# c.num_mels,
|
||||||
c.min_level_db,
|
# c.min_level_db,
|
||||||
c.frame_shift_ms,
|
# c.frame_shift_ms,
|
||||||
c.frame_length_ms,
|
# c.frame_length_ms,
|
||||||
c.preemphasis,
|
# c.preemphasis,
|
||||||
c.ref_level_db,
|
# c.ref_level_db,
|
||||||
c.num_freq,
|
# c.num_freq,
|
||||||
c.power
|
# c.power
|
||||||
)
|
# )
|
||||||
|
|
||||||
dataloader = DataLoader(dataset, batch_size=2,
|
# dataloader = DataLoader(dataset, batch_size=2,
|
||||||
shuffle=True, collate_fn=dataset.collate_fn,
|
# shuffle=True, collate_fn=dataset.collate_fn,
|
||||||
drop_last=True, num_workers=c.num_loader_workers)
|
# drop_last=True, num_workers=c.num_loader_workers)
|
||||||
|
|
||||||
for i, data in enumerate(dataloader):
|
# for i, data in enumerate(dataloader):
|
||||||
if i == self.max_loader_iter:
|
# if i == self.max_loader_iter:
|
||||||
break
|
# break
|
||||||
text_input = data[0]
|
# text_input = data[0]
|
||||||
text_lengths = data[1]
|
# text_lengths = data[1]
|
||||||
linear_input = data[2]
|
# linear_input = data[2]
|
||||||
mel_input = data[3]
|
# mel_input = data[3]
|
||||||
mel_lengths = data[4]
|
# mel_lengths = data[4]
|
||||||
stop_target = data[5]
|
# stop_target = data[5]
|
||||||
item_idx = data[6]
|
# item_idx = data[6]
|
||||||
|
|
||||||
neg_values = text_input[text_input < 0]
|
# neg_values = text_input[text_input < 0]
|
||||||
check_count = len(neg_values)
|
# check_count = len(neg_values)
|
||||||
assert check_count == 0, \
|
# assert check_count == 0, \
|
||||||
" !! Negative values in text_input: {}".format(check_count)
|
# " !! Negative values in text_input: {}".format(check_count)
|
||||||
# TODO: more assertion here
|
# # TODO: more assertion here
|
||||||
assert linear_input.shape[0] == c.batch_size
|
# assert linear_input.shape[0] == c.batch_size
|
||||||
assert mel_input.shape[0] == c.batch_size
|
# assert mel_input.shape[0] == c.batch_size
|
||||||
assert mel_input.shape[2] == c.num_mels
|
# assert mel_input.shape[2] == c.num_mels
|
||||||
|
|
||||||
def test_padding(self):
|
# def test_padding(self):
|
||||||
dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
|
# dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
|
||||||
os.path.join(c.data_path_TWEB, 'wavs'),
|
# os.path.join(c.data_path_TWEB, 'wavs'),
|
||||||
1,
|
# 1,
|
||||||
c.sample_rate,
|
# c.sample_rate,
|
||||||
c.text_cleaner,
|
# c.text_cleaner,
|
||||||
c.num_mels,
|
# c.num_mels,
|
||||||
c.min_level_db,
|
# c.min_level_db,
|
||||||
c.frame_shift_ms,
|
# c.frame_shift_ms,
|
||||||
c.frame_length_ms,
|
# c.frame_length_ms,
|
||||||
c.preemphasis,
|
# c.preemphasis,
|
||||||
c.ref_level_db,
|
# c.ref_level_db,
|
||||||
c.num_freq,
|
# c.num_freq,
|
||||||
c.power
|
# c.power
|
||||||
)
|
# )
|
||||||
|
|
||||||
# Test for batch size 1
|
# # Test for batch size 1
|
||||||
dataloader = DataLoader(dataset, batch_size=1,
|
# dataloader = DataLoader(dataset, batch_size=1,
|
||||||
shuffle=False, collate_fn=dataset.collate_fn,
|
# shuffle=False, collate_fn=dataset.collate_fn,
|
||||||
drop_last=False, num_workers=c.num_loader_workers)
|
# drop_last=False, num_workers=c.num_loader_workers)
|
||||||
|
|
||||||
for i, data in enumerate(dataloader):
|
# for i, data in enumerate(dataloader):
|
||||||
if i == self.max_loader_iter:
|
# if i == self.max_loader_iter:
|
||||||
break
|
# break
|
||||||
|
|
||||||
text_input = data[0]
|
# text_input = data[0]
|
||||||
text_lengths = data[1]
|
# text_lengths = data[1]
|
||||||
linear_input = data[2]
|
# linear_input = data[2]
|
||||||
mel_input = data[3]
|
# mel_input = data[3]
|
||||||
mel_lengths = data[4]
|
# mel_lengths = data[4]
|
||||||
stop_target = data[5]
|
# stop_target = data[5]
|
||||||
item_idx = data[6]
|
# item_idx = data[6]
|
||||||
|
|
||||||
# check the last time step to be zero padded
|
# # check the last time step to be zero padded
|
||||||
assert mel_input[0, -1].sum() == 0
|
# assert mel_input[0, -1].sum() == 0
|
||||||
assert mel_input[0, -2].sum() != 0, "{} -- {}".format(item_idx, i)
|
# assert mel_input[0, -2].sum() != 0, "{} -- {}".format(item_idx, i)
|
||||||
assert linear_input[0, -1].sum() == 0
|
# assert linear_input[0, -1].sum() == 0
|
||||||
assert linear_input[0, -2].sum() != 0
|
# assert linear_input[0, -2].sum() != 0
|
||||||
assert stop_target[0, -1] == 1
|
# assert stop_target[0, -1] == 1
|
||||||
assert stop_target[0, -2] == 0
|
# assert stop_target[0, -2] == 0
|
||||||
assert stop_target.sum() == 1
|
# assert stop_target.sum() == 1
|
||||||
assert len(mel_lengths.shape) == 1
|
# assert len(mel_lengths.shape) == 1
|
||||||
assert mel_lengths[0] == mel_input[0].shape[0]
|
# assert mel_lengths[0] == mel_input[0].shape[0]
|
||||||
|
|
||||||
# Test for batch size 2
|
# # Test for batch size 2
|
||||||
dataloader = DataLoader(dataset, batch_size=2,
|
# dataloader = DataLoader(dataset, batch_size=2,
|
||||||
shuffle=False, collate_fn=dataset.collate_fn,
|
# shuffle=False, collate_fn=dataset.collate_fn,
|
||||||
drop_last=False, num_workers=c.num_loader_workers)
|
# drop_last=False, num_workers=c.num_loader_workers)
|
||||||
|
|
||||||
for i, data in enumerate(dataloader):
|
# for i, data in enumerate(dataloader):
|
||||||
if i == self.max_loader_iter:
|
# if i == self.max_loader_iter:
|
||||||
break
|
# break
|
||||||
text_input = data[0]
|
# text_input = data[0]
|
||||||
text_lengths = data[1]
|
# text_lengths = data[1]
|
||||||
linear_input = data[2]
|
# linear_input = data[2]
|
||||||
mel_input = data[3]
|
# mel_input = data[3]
|
||||||
mel_lengths = data[4]
|
# mel_lengths = data[4]
|
||||||
stop_target = data[5]
|
# stop_target = data[5]
|
||||||
item_idx = data[6]
|
# item_idx = data[6]
|
||||||
|
|
||||||
if mel_lengths[0] > mel_lengths[1]:
|
# if mel_lengths[0] > mel_lengths[1]:
|
||||||
idx = 0
|
# idx = 0
|
||||||
else:
|
# else:
|
||||||
idx = 1
|
# idx = 1
|
||||||
|
|
||||||
# check the first item in the batch
|
# # check the first item in the batch
|
||||||
assert mel_input[idx, -1].sum() == 0
|
# assert mel_input[idx, -1].sum() == 0
|
||||||
assert mel_input[idx, -2].sum() != 0, mel_input
|
# assert mel_input[idx, -2].sum() != 0, mel_input
|
||||||
assert linear_input[idx, -1].sum() == 0
|
# assert linear_input[idx, -1].sum() == 0
|
||||||
assert linear_input[idx, -2].sum() != 0
|
# assert linear_input[idx, -2].sum() != 0
|
||||||
assert stop_target[idx, -1] == 1
|
# assert stop_target[idx, -1] == 1
|
||||||
assert stop_target[idx, -2] == 0
|
# assert stop_target[idx, -2] == 0
|
||||||
assert stop_target[idx].sum() == 1
|
# assert stop_target[idx].sum() == 1
|
||||||
assert len(mel_lengths.shape) == 1
|
# assert len(mel_lengths.shape) == 1
|
||||||
assert mel_lengths[idx] == mel_input[idx].shape[0]
|
# assert mel_lengths[idx] == mel_input[idx].shape[0]
|
||||||
|
|
||||||
# check the second itme in the batch
|
# # check the second itme in the batch
|
||||||
assert mel_input[1-idx, -1].sum() == 0
|
# assert mel_input[1-idx, -1].sum() == 0
|
||||||
assert linear_input[1-idx, -1].sum() == 0
|
# assert linear_input[1-idx, -1].sum() == 0
|
||||||
assert stop_target[1-idx, -1] == 1
|
# assert stop_target[1-idx, -1] == 1
|
||||||
assert len(mel_lengths.shape) == 1
|
# assert len(mel_lengths.shape) == 1
|
||||||
|
|
||||||
# check batch conditions
|
# # check batch conditions
|
||||||
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
|
# assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
|
||||||
assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
|
# assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
|
||||||
|
|
|
@ -42,20 +42,21 @@ class TacotronTrainTest(unittest.TestCase):
|
||||||
count += 1
|
count += 1
|
||||||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
mel_out, linear_out, align, stop_tokens = model.forward(input, mel_spec)
|
mel_out, linear_out, align = model.forward(input, mel_spec)
|
||||||
assert stop_tokens.data.max() <= 1.0
|
# mel_out, linear_out, align, stop_tokens = model.forward(input, mel_spec)
|
||||||
assert stop_tokens.data.min() >= 0.0
|
# assert stop_tokens.data.max() <= 1.0
|
||||||
|
# assert stop_tokens.data.min() >= 0.0
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss = criterion(mel_out, mel_spec, mel_lengths)
|
loss = criterion(mel_out, mel_spec, mel_lengths)
|
||||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
# stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||||
loss = loss + criterion(linear_out, linear_spec, mel_lengths) + stop_loss
|
loss = loss + criterion(linear_out, linear_spec, mel_lengths)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
# check parameter changes
|
# check parameter changes
|
||||||
count = 0
|
count = 0
|
||||||
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
|
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
|
||||||
# ignore pre-higway layer since it works conditional
|
# ignore pre-higway layer since it works conditional
|
||||||
if count not in [141, 59]:
|
if count not in [139, 59]:
|
||||||
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(count, param.shape, param, param_ref)
|
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(count, param.shape, param, param_ref)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue