update pwgan generator test

This commit is contained in:
erogol 2020-07-17 13:01:45 +02:00
parent 7bfe16f130
commit 9ce4126482
1 changed files with 3 additions and 3 deletions

View File

@ -22,9 +22,9 @@ def test_pwgan_generator():
use_causal_conv=False,
upsample_conditional_features=True,
upsample_factors=[4, 4, 4, 4])
dummy_c = torch.rand((4, 80, 64))
dummy_c = torch.rand((2, 80, 5))
output = model(dummy_c)
assert np.all(output.shape == (4, 1, 64 * 256))
assert np.all(output.shape == (2, 1, 5 * 256)), output.shape
model.remove_weight_norm()
output = model.inference(dummy_c)
assert np.all(output.shape == (4, 1, (64 + 4) * 256))
assert np.all(output.shape == (2, 1, (5 + 4) * 256))