fix unittests for the latest updates

This commit is contained in:
Eren Golge 2019-07-19 11:12:48 +02:00
parent d5cefe0546
commit 4336e1d338
8 changed files with 54 additions and 36 deletions

View File

@ -13,7 +13,6 @@ from utils.data import (prepare_data, pad_per_step, prepare_tensor,
class MyDataset(Dataset):
def __init__(self,
root_path,
outputs_per_step,
text_cleaner,
ap,
@ -28,13 +27,10 @@ class MyDataset(Dataset):
verbose=False):
"""
Args:
root_path (str): root path for the data folder.
outputs_per_step (int): number of time frames predicted per step.
text_cleaner (str): text cleaner used for the dataset.
ap (TTS.utils.AudioProcessor): audio processor object.
meta_data (list): list of dataset instances.
speaker_id_cache_path (str): path where the speaker name to id
mapping is stored
batch_group_size (int): (0) range of batch randomization after sorting
sequences by length.
min_seq_len (int): (0) minimum sequence length to be processed
@ -47,7 +43,6 @@ class MyDataset(Dataset):
enable_eos_bos (bool): enable end of sentence and beginning of sentences characters.
verbose (bool): print diagnostic information.
"""
self.root_path = root_path
self.batch_group_size = batch_group_size
self.items = meta_data
self.outputs_per_step = outputs_per_step
@ -65,7 +60,6 @@ class MyDataset(Dataset):
os.makedirs(phoneme_cache_path, exist_ok=True)
if self.verbose:
print("\n > DataLoader initialization")
print(" | > Data path: {}".format(root_path))
print(" | > Use phonemes: {}".format(self.use_phonemes))
if use_phonemes:
print(" | > phoneme language: {}".format(phoneme_language))

View File

@ -70,6 +70,8 @@ class Tacotron(nn.Module):
return mel_outputs, linear_outputs, alignments, stop_tokens
def _add_speaker_embedding(self, encoder_outputs, speaker_ids):
if hasattr(self, "speaker_embedding") and speaker_ids is None:
raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
if hasattr(self, "speaker_embedding") and speaker_ids is not None:
speaker_embeddings = self.speaker_embedding(speaker_ids)

View File

@ -89,7 +89,9 @@ class Tacotron2(nn.Module):
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
def _add_speaker_embedding(self, encoder_outputs, speaker_ids):
if hasattr(self, "speaker_embedding") and speaker_ids is not None:
if hasattr(self, "speaker_embedding") and speaker_ids is None:
raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
elif hasattr(self, "speaker_embedding") and speaker_ids is not None:
speaker_embeddings = self.speaker_embedding(speaker_ids)
speaker_embeddings.unsqueeze_(1)

View File

@ -38,11 +38,25 @@ class CBHGTests(unittest.TestCase):
class DecoderTests(unittest.TestCase):
def test_in_out(self):
layer = Decoder(in_features=256, memory_dim=80, r=2, memory_size=4, attn_windowing=False, attn_norm="sigmoid")
layer = Decoder(
in_features=256,
memory_dim=80,
r=2,
memory_size=4,
attn_windowing=False,
attn_norm="sigmoid",
prenet_type='original',
prenet_dropout=True,
forward_attn=True,
trans_agent=True,
forward_attn_mask=True,
location_attn=True,
separate_stopnet=True)
dummy_input = T.rand(4, 8, 256)
dummy_memory = T.rand(4, 2, 80)
output, alignment, stop_tokens = layer(dummy_input, dummy_memory, mask=None)
output, alignment, stop_tokens = layer(
dummy_input, dummy_memory, mask=None)
assert output.shape[0] == 4
assert output.shape[1] == 1, "size not {}".format(output.shape[1])

View File

@ -29,13 +29,12 @@ class TestTTSDataset(unittest.TestCase):
self.ap = AudioProcessor(**c.audio)
def _create_dataloader(self, batch_size, r, bgs):
items = ljspeech(c.data_path,'metadata.csv')
dataset = TTSDataset.MyDataset(
c.data_path,
'metadata.csv',
r,
c.text_cleaner,
preprocessor=ljspeech,
ap=self.ap,
meta_data=items,
batch_group_size=bgs,
min_seq_len=c.min_seq_len,
max_seq_len=float("inf"),
@ -58,17 +57,19 @@ class TestTTSDataset(unittest.TestCase):
break
text_input = data[0]
text_lengths = data[1]
linear_input = data[2]
mel_input = data[3]
mel_lengths = data[4]
stop_target = data[5]
item_idx = data[6]
speaker_name = data[2]
linear_input = data[3]
mel_input = data[4]
mel_lengths = data[5]
stop_target = data[6]
item_idx = data[7]
neg_values = text_input[text_input < 0]
check_count = len(neg_values)
assert check_count == 0, \
" !! Negative values in text_input: {}".format(check_count)
# TODO: more assertion here
assert type(speaker_name[0]) is str
assert linear_input.shape[0] == c.batch_size
assert linear_input.shape[2] == self.ap.num_freq
assert mel_input.shape[0] == c.batch_size
@ -92,11 +93,12 @@ class TestTTSDataset(unittest.TestCase):
break
text_input = data[0]
text_lengths = data[1]
linear_input = data[2]
mel_input = data[3]
mel_lengths = data[4]
stop_target = data[5]
item_idx = data[6]
speaker_name = data[2]
linear_input = data[3]
mel_input = data[4]
mel_lengths = data[5]
stop_target = data[6]
item_idx = data[7]
avg_length = mel_lengths.numpy().mean()
assert avg_length >= last_length
@ -112,11 +114,12 @@ class TestTTSDataset(unittest.TestCase):
break
text_input = data[0]
text_lengths = data[1]
linear_input = data[2]
mel_input = data[3]
mel_lengths = data[4]
stop_target = data[5]
item_idx = data[6]
speaker_name = data[2]
linear_input = data[3]
mel_input = data[4]
mel_lengths = data[5]
stop_target = data[6]
item_idx = data[7]
# check mel_spec consistency
wav = self.ap.load_wav(item_idx[0])
@ -159,11 +162,12 @@ class TestTTSDataset(unittest.TestCase):
break
text_input = data[0]
text_lengths = data[1]
linear_input = data[2]
mel_input = data[3]
mel_lengths = data[4]
stop_target = data[5]
item_idx = data[6]
speaker_name = data[2]
linear_input = data[3]
mel_input = data[4]
mel_lengths = data[5]
stop_target = data[6]
item_idx = data[7]
if mel_lengths[0] > mel_lengths[1]:
idx = 0

View File

@ -27,6 +27,7 @@ class TacotronTrainTest(unittest.TestCase):
mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
stop_targets = torch.zeros(8, 30, 1).float().to(device)
speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
for idx in mel_lengths:
stop_targets[:, int(idx.item()):, 0] = 1.0
@ -37,7 +38,7 @@ class TacotronTrainTest(unittest.TestCase):
criterion = MSELossMasked().to(device)
criterion_st = nn.BCEWithLogitsLoss().to(device)
model = Tacotron2(24, c.r).to(device)
model = Tacotron2(24, c.r, 5).to(device)
model.train()
model_ref = copy.deepcopy(model)
count = 0
@ -48,7 +49,7 @@ class TacotronTrainTest(unittest.TestCase):
optimizer = optim.Adam(model.parameters(), lr=c.lr)
for i in range(5):
mel_out, mel_postnet_out, align, stop_tokens = model.forward(
input, input_lengths, mel_spec)
input, input_lengths, mel_spec, speaker_ids)
assert torch.sigmoid(stop_tokens).data.max() <= 1.0
assert torch.sigmoid(stop_tokens).data.min() >= 0.0
optimizer.zero_grad()

View File

@ -32,6 +32,7 @@ class TacotronTrainTest(unittest.TestCase):
linear_spec = torch.rand(8, 30, c.audio['num_freq']).to(device)
mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
stop_targets = torch.zeros(8, 30, 1).float().to(device)
speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
for idx in mel_lengths:
stop_targets[:, int(idx.item()):, 0] = 1.0
@ -45,6 +46,7 @@ class TacotronTrainTest(unittest.TestCase):
criterion_st = nn.BCEWithLogitsLoss().to(device)
model = Tacotron(
32,
5,
linear_dim=c.audio['num_freq'],
mel_dim=c.audio['num_mels'],
r=c.r,
@ -60,7 +62,7 @@ class TacotronTrainTest(unittest.TestCase):
optimizer = optim.Adam(model.parameters(), lr=c.lr)
for i in range(5):
mel_out, linear_out, align, stop_tokens = model.forward(
input, input_lengths, mel_spec)
input, input_lengths, mel_spec, speaker_ids)
optimizer.zero_grad()
loss = criterion(mel_out, mel_spec, mel_lengths)
stop_loss = criterion_st(stop_tokens, stop_targets)

View File

@ -58,7 +58,6 @@ def setup_loader(is_val=False, verbose=False):
loader = None
else:
dataset = MyDataset(
c.data_path,
c.r,
c.text_cleaner,
meta_data=meta_data_eval if is_val else meta_data_train,