mirror of https://github.com/coqui-ai/TTS.git
Update tests
This commit is contained in:
parent
0a92c6d5a7
commit
09b1a7b612
|
@ -48,8 +48,6 @@ class DecoderTests(unittest.TestCase):
|
|||
assert output.shape[1] == 1, "size not {}".format(output.shape[1])
|
||||
assert output.shape[2] == 80 * 2, "size not {}".format(output.shape[2])
|
||||
assert stop_tokens.shape[0] == 4
|
||||
assert stop_tokens.max() <= 1.0
|
||||
assert stop_tokens.min() >= 0
|
||||
|
||||
|
||||
class EncoderTests(unittest.TestCase):
|
||||
|
|
|
@ -33,10 +33,10 @@ class TacotronTrainTest(unittest.TestCase):
|
|||
|
||||
stop_targets = stop_targets.view(input.shape[0],
|
||||
stop_targets.size(1) // c.r, -1)
|
||||
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
|
||||
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
|
||||
|
||||
criterion = L1LossMasked().to(device)
|
||||
criterion_st = nn.BCELoss().to(device)
|
||||
criterion_st = nn.BCEWithLogitsLoss().to(device)
|
||||
model = Tacotron(32, c.audio['num_freq'], c.audio['num_mels'],
|
||||
c.r, memory_size=c.memory_size).to(device)
|
||||
model.train()
|
||||
|
@ -50,8 +50,6 @@ class TacotronTrainTest(unittest.TestCase):
|
|||
for i in range(5):
|
||||
mel_out, linear_out, align, stop_tokens = model.forward(
|
||||
input, input_lengths, mel_spec)
|
||||
assert stop_tokens.data.max() <= 1.0
|
||||
assert stop_tokens.data.min() >= 0.0
|
||||
optimizer.zero_grad()
|
||||
loss = criterion(mel_out, mel_spec, mel_lengths)
|
||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||
|
|
Loading…
Reference in New Issue