From b0bad56ba95c32577f5a2e272a14b8784b2364f1 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 10 Mar 2022 17:02:37 -0300 Subject: [PATCH] Add Perfect Batch Sampler unit test --- TTS/encoder/utils/samplers.py | 4 +-- tests/data_tests/test_samplers.py | 49 +++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/TTS/encoder/utils/samplers.py b/TTS/encoder/utils/samplers.py index 935aa067..947f5da0 100644 --- a/TTS/encoder/utils/samplers.py +++ b/TTS/encoder/utils/samplers.py @@ -34,14 +34,14 @@ class PerfectBatchSampler(Sampler): drop_last (bool): if True, drops last incomplete batch. """ - def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False): + def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False, label_key="class_name"): super().__init__(dataset_items) assert batch_size % (num_classes_in_batch * num_gpus) == 0, ( 'Batch size must be divisible by number of classes times the number of data parallel devices (if enabled).') label_indices = {} for idx, item in enumerate(dataset_items): - label = item['class_name'] + label = item[label_key] if label not in label_indices.keys(): label_indices[label] = [idx] else: diff --git a/tests/data_tests/test_samplers.py b/tests/data_tests/test_samplers.py index 12152fb8..c888c629 100644 --- a/tests/data_tests/test_samplers.py +++ b/tests/data_tests/test_samplers.py @@ -8,6 +8,7 @@ from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.utils.languages import get_language_balancer_weights from TTS.tts.utils.speakers import get_speaker_balancer_weights +from TTS.encoder.utils.samplers import PerfectBatchSampler # Fixing random state to avoid random fails torch.manual_seed(0) @@ -82,3 +83,51 @@ class TestSamplers(unittest.TestCase): spk2 += 1 assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced" + + def test_perfect_sampler(self): # pylint: disable=no-self-use + classes = set() + for item in train_samples: + classes.add(item["speaker_name"]) + + sampler = PerfectBatchSampler( + train_samples, + classes, + batch_size=2 * 3, # total batch size + num_classes_in_batch=2, + label_key="speaker_name", + shuffle=False, + drop_last=True) + batchs = functools.reduce(lambda a, b: a + b, [list(sampler) for i in range(100)]) + for batch in batchs: + spk1, spk2 = 0, 0 + # for in each batch + for index in batch: + if train_samples[index]["speaker_name"] == "ljspeech-0": + spk1 += 1 + else: + spk2 += 1 + assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced" + + def test_perfect_sampler_shuffle(self): # pylint: disable=no-self-use + classes = set() + for item in train_samples: + classes.add(item["speaker_name"]) + + sampler = PerfectBatchSampler( + train_samples, + classes, + batch_size=2 * 3, # total batch size + num_classes_in_batch=2, + label_key="speaker_name", + shuffle=True, + drop_last=False) + batchs = functools.reduce(lambda a, b: a + b, [list(sampler) for i in range(100)]) + for batch in batchs: + spk1, spk2 = 0, 0 + # for in each batch + for index in batch: + if train_samples[index]["speaker_name"] == "ljspeech-0": + spk1 += 1 + else: + spk2 += 1 + assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced"