From 002f826a1c5110d4ae92d63a3bae88de92a2212f Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Sat, 7 May 2022 15:56:39 -0300 Subject: [PATCH] Add unit tests --- TTS/tts/utils/data.py | 4 ++-- tests/data_tests/test_samplers.py | 27 +++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/TTS/tts/utils/data.py b/TTS/tts/utils/data.py index c0744177..22e46b68 100644 --- a/TTS/tts/utils/data.py +++ b/TTS/tts/utils/data.py @@ -62,8 +62,8 @@ def get_length_balancer_weights(items: list, num_buckets=10): # create the $num_buckets buckets classes based in the dataset max and min length max_length = int(max(audio_lengths)) min_length = int(min(audio_lengths)) - step = int((max_length - min_length) / num_buckets) - buckets_classes = [i + step for i in range(min_length, max_length, step)] + step = int((max_length - min_length) / num_buckets) + 1 + buckets_classes = [i + step for i in range(min_length, (max_length - step) + num_buckets + 1, step)] # add each sample in their respective length bucket buckets_names = np.array( [buckets_classes[bisect.bisect_left(buckets_classes, item["audio_length"])] for item in items] diff --git a/tests/data_tests/test_samplers.py b/tests/data_tests/test_samplers.py index 42f1bfd5..b85e0ec4 100644 --- a/tests/data_tests/test_samplers.py +++ b/tests/data_tests/test_samplers.py @@ -1,4 +1,5 @@ import functools +import random import unittest import torch @@ -6,6 +7,7 @@ import torch from TTS.config.shared_configs import BaseDatasetConfig from TTS.encoder.utils.samplers import PerfectBatchSampler from TTS.tts.datasets import load_tts_samples +from TTS.tts.utils.data import get_length_balancer_weights from TTS.tts.utils.languages import get_language_balancer_weights from TTS.tts.utils.speakers import get_speaker_balancer_weights @@ -136,3 +138,28 @@ class TestSamplers(unittest.TestCase): else: spk2 += 1 assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced" + + def 
test_length_weighted_random_sampler(self): # pylint: disable=no-self-use + for _ in range(1000): + # generate a length-unbalanced dataset with random max/min audio length + min_audio = random.randrange(1, 22050) + max_audio = random.randrange(44100, 220500) + for idx, item in enumerate(train_samples): + # increase the diversity of durations + random_increase = random.randrange(100, 1000) + if idx < 5: + item["audio_length"] = min_audio + random_increase + else: + item["audio_length"] = max_audio + random_increase + + weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler( + get_length_balancer_weights(train_samples, num_buckets=2), len(train_samples) + ) + ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)]) + len1, len2 = 0, 0 + for index in ids: + if train_samples[index]["audio_length"] < max_audio: + len1 += 1 + else: + len2 += 1 + assert is_balanced(len1, len2), "Length Weighted sampler is supposed to be balanced"