Add Perfect Batch Sampler unit test

This commit is contained in:
Edresson Casanova 2022-03-10 17:02:37 -03:00
parent 9c8b8201c3
commit b0bad56ba9
2 changed files with 51 additions and 2 deletions

View File

@ -34,14 +34,14 @@ class PerfectBatchSampler(Sampler):
drop_last (bool): if True, drops last incomplete batch.
"""
def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False):
def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False, label_key="class_name"):
super().__init__(dataset_items)
assert batch_size % (num_classes_in_batch * num_gpus) == 0, (
'Batch size must be divisible by number of classes times the number of data parallel devices (if enabled).')
label_indices = {}
for idx, item in enumerate(dataset_items):
label = item['class_name']
label = item[label_key]
if label not in label_indices.keys():
label_indices[label] = [idx]
else:

View File

@ -8,6 +8,7 @@ from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.languages import get_language_balancer_weights
from TTS.tts.utils.speakers import get_speaker_balancer_weights
from TTS.encoder.utils.samplers import PerfectBatchSampler
# Fixing random state to avoid random fails
torch.manual_seed(0)
@ -82,3 +83,51 @@ class TestSamplers(unittest.TestCase):
spk2 += 1
assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced"
def test_perfect_sampler(self): # pylint: disable=no-self-use
classes = set()
for item in train_samples:
classes.add(item["speaker_name"])
sampler = PerfectBatchSampler(
train_samples,
classes,
batch_size=2 * 3, # total batch size
num_classes_in_batch=2,
label_key="speaker_name",
shuffle=False,
drop_last=True)
batchs = functools.reduce(lambda a, b: a + b, [list(sampler) for i in range(100)])
for batch in batchs:
spk1, spk2 = 0, 0
# for in each batch
for index in batch:
if train_samples[index]["speaker_name"] == "ljspeech-0":
spk1 += 1
else:
spk2 += 1
assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced"
def test_perfect_sampler_shuffle(self): # pylint: disable=no-self-use
classes = set()
for item in train_samples:
classes.add(item["speaker_name"])
sampler = PerfectBatchSampler(
train_samples,
classes,
batch_size=2 * 3, # total batch size
num_classes_in_batch=2,
label_key="speaker_name",
shuffle=True,
drop_last=False)
batchs = functools.reduce(lambda a, b: a + b, [list(sampler) for i in range(100)])
for batch in batchs:
spk1, spk2 = 0, 0
# for in each batch
for index in batch:
if train_samples[index]["speaker_name"] == "ljspeech-0":
spk1 += 1
else:
spk2 += 1
assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced"