Update tests

This commit is contained in:
Eren Golge 2019-03-26 00:48:35 +01:00
parent 0a92c6d5a7
commit 09b1a7b612
2 changed files with 2 additions and 6 deletions

View File

@ -48,8 +48,6 @@ class DecoderTests(unittest.TestCase):
assert output.shape[1] == 1, "size not {}".format(output.shape[1])
assert output.shape[2] == 80 * 2, "size not {}".format(output.shape[2])
assert stop_tokens.shape[0] == 4
assert stop_tokens.max() <= 1.0
assert stop_tokens.min() >= 0
class EncoderTests(unittest.TestCase):

View File

@ -33,10 +33,10 @@ class TacotronTrainTest(unittest.TestCase):
stop_targets = stop_targets.view(input.shape[0],
stop_targets.size(1) // c.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
criterion = L1LossMasked().to(device)
criterion_st = nn.BCELoss().to(device)
criterion_st = nn.BCEWithLogitsLoss().to(device)
model = Tacotron(32, c.audio['num_freq'], c.audio['num_mels'],
c.r, memory_size=c.memory_size).to(device)
model.train()
@ -50,8 +50,6 @@ class TacotronTrainTest(unittest.TestCase):
for i in range(5):
mel_out, linear_out, align, stop_tokens = model.forward(
input, input_lengths, mel_spec)
assert stop_tokens.data.max() <= 1.0
assert stop_tokens.data.min() >= 0.0
optimizer.zero_grad()
loss = criterion(mel_out, mel_spec, mel_lengths)
stop_loss = criterion_st(stop_tokens, stop_targets)