Add Perfect Batch Sampler unit test

This commit is contained in:
Edresson Casanova 2022-03-10 17:02:37 -03:00
parent 9c8b8201c3
commit b0bad56ba9
2 changed files with 51 additions and 2 deletions

View File

@@ -34,14 +34,14 @@ class PerfectBatchSampler(Sampler):
drop_last (bool): if True, drops last incomplete batch. drop_last (bool): if True, drops last incomplete batch.
""" """
def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False): def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False, label_key="class_name"):
super().__init__(dataset_items) super().__init__(dataset_items)
assert batch_size % (num_classes_in_batch * num_gpus) == 0, ( assert batch_size % (num_classes_in_batch * num_gpus) == 0, (
'Batch size must be divisible by number of classes times the number of data parallel devices (if enabled).') 'Batch size must be divisible by number of classes times the number of data parallel devices (if enabled).')
label_indices = {} label_indices = {}
for idx, item in enumerate(dataset_items): for idx, item in enumerate(dataset_items):
label = item['class_name'] label = item[label_key]
if label not in label_indices.keys(): if label not in label_indices.keys():
label_indices[label] = [idx] label_indices[label] = [idx]
else: else:

View File

@@ -8,6 +8,7 @@ from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.languages import get_language_balancer_weights from TTS.tts.utils.languages import get_language_balancer_weights
from TTS.tts.utils.speakers import get_speaker_balancer_weights from TTS.tts.utils.speakers import get_speaker_balancer_weights
from TTS.encoder.utils.samplers import PerfectBatchSampler
# Fixing random state to avoid random fails # Fixing random state to avoid random fails
torch.manual_seed(0) torch.manual_seed(0)
@@ -82,3 +83,51 @@ class TestSamplers(unittest.TestCase):
spk2 += 1 spk2 += 1
assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced" assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced"
def test_perfect_sampler(self):  # pylint: disable=no-self-use
    """PerfectBatchSampler without shuffling must emit perfectly class-balanced batches.

    Builds the sampler over the module-level ``train_samples`` (two speakers),
    draws 100 full epochs of batches, and checks every batch contains the same
    number of items per speaker.
    """
    speaker_classes = {item["speaker_name"] for item in train_samples}
    sampler = PerfectBatchSampler(
        train_samples,
        speaker_classes,
        batch_size=2 * 3,  # total batch size
        num_classes_in_batch=2,
        label_key="speaker_name",
        shuffle=False,
        drop_last=True,
    )
    # Flatten 100 epochs of batches into one list.
    batches = functools.reduce(lambda acc, epoch: acc + epoch, [list(sampler) for _ in range(100)])
    for batch in batches:
        # Count items per speaker within this single batch.
        first_speaker_count = 0
        other_speaker_count = 0
        for sample_idx in batch:
            if train_samples[sample_idx]["speaker_name"] == "ljspeech-0":
                first_speaker_count += 1
            else:
                other_speaker_count += 1
        assert first_speaker_count == other_speaker_count, "PerfectBatchSampler is supposed to be perfectly balanced"
def test_perfect_sampler_shuffle(self):  # pylint: disable=no-self-use
    """PerfectBatchSampler with shuffling enabled must still emit class-balanced batches.

    Same check as the non-shuffled case, but with ``shuffle=True`` and
    ``drop_last=False`` to exercise the randomized sampling path over 100 epochs.
    """
    speaker_classes = {item["speaker_name"] for item in train_samples}
    sampler = PerfectBatchSampler(
        train_samples,
        speaker_classes,
        batch_size=2 * 3,  # total batch size
        num_classes_in_batch=2,
        label_key="speaker_name",
        shuffle=True,
        drop_last=False,
    )
    # Flatten 100 epochs of batches into one list.
    batches = functools.reduce(lambda acc, epoch: acc + epoch, [list(sampler) for _ in range(100)])
    for batch in batches:
        # Count items per speaker within this single batch.
        first_speaker_count = 0
        other_speaker_count = 0
        for sample_idx in batch:
            if train_samples[sample_idx]["speaker_name"] == "ljspeech-0":
                first_speaker_count += 1
            else:
                other_speaker_count += 1
        assert first_speaker_count == other_speaker_count, "PerfectBatchSampler is supposed to be perfectly balanced"