mirror of https://github.com/coqui-ai/TTS.git
test(helpers): fix test_rand_segment, test_generate_path
This commit is contained in:
parent
c9f7197862
commit
857cd55ce5
|
@ -53,8 +53,8 @@ def test_segment():
|
|||
def test_rand_segments():
|
||||
x = T.rand(2, 3, 4)
|
||||
x_lens = T.randint(3, 4, (2,))
|
||||
segments, seg_idxs = rand_segments(x, x_lens, segment_size=3)
|
||||
assert segments.shape == (2, 3, 3)
|
||||
segments, seg_idxs = rand_segments(x, x_lens, segment_size=2)
|
||||
assert segments.shape == (2, 3, 2)
|
||||
assert all(seg_idxs >= 0), seg_idxs
|
||||
try:
|
||||
segments, _ = rand_segments(x, x_lens, segment_size=5)
|
||||
|
@ -71,7 +71,7 @@ def test_rand_segments():
|
|||
def test_generate_path():
|
||||
durations = T.randint(1, 4, (10, 21))
|
||||
x_length = T.randint(18, 22, (10,))
|
||||
x_mask = sequence_mask(x_length).unsqueeze(1).long()
|
||||
x_mask = sequence_mask(x_length, max_len=21).unsqueeze(1).long()
|
||||
durations = durations * x_mask.squeeze(1)
|
||||
y_length = durations.sum(1)
|
||||
y_mask = sequence_mask(y_length).unsqueeze(1).long()
|
||||
|
|
Loading…
Reference in New Issue