Add unit tests

This commit is contained in:
Edresson Casanova 2022-05-07 15:56:39 -03:00
parent e86e3d2e87
commit 002f826a1c
2 changed files with 29 additions and 2 deletions

View File

@ -62,8 +62,8 @@ def get_length_balancer_weights(items: list, num_buckets=10):
# create the $num_buckets buckets classes based in the dataset max and min length
max_length = int(max(audio_lengths))
min_length = int(min(audio_lengths))
step = int((max_length - min_length) / num_buckets)
buckets_classes = [i + step for i in range(min_length, max_length, step)]
step = int((max_length - min_length) / num_buckets) + 1
buckets_classes = [i + step for i in range(min_length, (max_length - step) + num_buckets + 1, step)]
# add each sample in their respective length bucket
buckets_names = np.array(
[buckets_classes[bisect.bisect_left(buckets_classes, item["audio_length"])] for item in items]

View File

@ -1,4 +1,5 @@
import functools
import random
import unittest
import torch
@ -6,6 +7,7 @@ import torch
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.encoder.utils.samplers import PerfectBatchSampler
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.data import get_length_balancer_weights
from TTS.tts.utils.languages import get_language_balancer_weights
from TTS.tts.utils.speakers import get_speaker_balancer_weights
@ -136,3 +138,28 @@ class TestSamplers(unittest.TestCase):
else:
spk2 += 1
assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced"
def test_length_weighted_random_sampler(self): # pylint: disable=no-self-use
for _ in range(1000):
# gerenate a lenght unbalanced dataset with random max/min audio lenght
min_audio = random.randrange(1, 22050)
max_audio = random.randrange(44100, 220500)
for idx, item in enumerate(train_samples):
# increase the diversity of durations
random_increase = random.randrange(100, 1000)
if idx < 5:
item["audio_length"] = min_audio + random_increase
else:
item["audio_length"] = max_audio + random_increase
weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(
get_length_balancer_weights(train_samples, num_buckets=2), len(train_samples)
)
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
len1, len2 = 0, 0
for index in ids:
if train_samples[index]["audio_length"] < max_audio:
len1 += 1
else:
len2 += 1
assert is_balanced(len1, len2), "Length Weighted sampler is supposed to be balanced"