test(helpers): add test_ prefix so tests actually run

This commit is contained in:
Enno Hermann 2024-06-20 14:16:45 +02:00
parent 98c0f86cb3
commit c9f7197862
1 changed files with 5 additions and 5 deletions

View File

@ -3,7 +3,7 @@ import torch as T
from TTS.tts.utils.helpers import average_over_durations, generate_path, rand_segments, segment, sequence_mask
def average_over_durations_test(): # pylint: disable=no-self-use
def test_average_over_durations(): # pylint: disable=no-self-use
pitch = T.rand(1, 1, 128)
durations = T.randint(1, 5, (1, 21))
@ -21,7 +21,7 @@ def average_over_durations_test(): # pylint: disable=no-self-use
index += dur
def seqeunce_mask_test():
def test_sequence_mask():
lengths = T.randint(10, 15, (8,))
mask = sequence_mask(lengths)
for i in range(8):
@ -30,7 +30,7 @@ def seqeunce_mask_test():
assert mask[i, l:].sum() == 0
def segment_test():
def test_segment():
x = T.range(0, 11)
x = x.repeat(8, 1).unsqueeze(1)
segment_ids = T.randint(0, 7, (8,))
@ -50,7 +50,7 @@ def segment_test():
assert x[idx, :, start_indx : start_indx + 10].sum() == segments[idx, :, :].sum()
def rand_segments_test():
def test_rand_segments():
x = T.rand(2, 3, 4)
x_lens = T.randint(3, 4, (2,))
segments, seg_idxs = rand_segments(x, x_lens, segment_size=3)
@ -68,7 +68,7 @@ def rand_segments_test():
assert all(x_lens_back == x_lens)
def generate_path_test():
def test_generate_path():
durations = T.randint(1, 4, (10, 21))
x_length = T.randint(18, 22, (10,))
x_mask = sequence_mask(x_length).unsqueeze(1).long()