From 857cd55ce5b63f554bafc797176b08d750e8fc9e Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Thu, 20 Jun 2024 14:28:44 +0200 Subject: [PATCH] test(helpers): fix test_rand_segment, test_generate_path --- tests/tts_tests/test_helpers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/tts_tests/test_helpers.py b/tests/tts_tests/test_helpers.py index a83ec9dd..dbd7f54e 100644 --- a/tests/tts_tests/test_helpers.py +++ b/tests/tts_tests/test_helpers.py @@ -53,8 +53,8 @@ def test_segment(): def test_rand_segments(): x = T.rand(2, 3, 4) x_lens = T.randint(3, 4, (2,)) - segments, seg_idxs = rand_segments(x, x_lens, segment_size=3) - assert segments.shape == (2, 3, 3) + segments, seg_idxs = rand_segments(x, x_lens, segment_size=2) + assert segments.shape == (2, 3, 2) assert all(seg_idxs >= 0), seg_idxs try: segments, _ = rand_segments(x, x_lens, segment_size=5) @@ -71,7 +71,7 @@ def test_rand_segments(): def test_generate_path(): durations = T.randint(1, 4, (10, 21)) x_length = T.randint(18, 22, (10,)) - x_mask = sequence_mask(x_length).unsqueeze(1).long() + x_mask = sequence_mask(x_length, max_len=21).unsqueeze(1).long() durations = durations * x_mask.squeeze(1) y_length = durations.sum(1) y_mask = sequence_mask(y_length).unsqueeze(1).long()