mirror of https://github.com/coqui-ai/TTS.git
Add Perfect Batch Sampler unit test
This commit is contained in:
parent
9c8b8201c3
commit
b0bad56ba9
|
@ -34,14 +34,14 @@ class PerfectBatchSampler(Sampler):
|
|||
drop_last (bool): if True, drops last incomplete batch.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False):
|
||||
def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False, label_key="class_name"):
|
||||
super().__init__(dataset_items)
|
||||
assert batch_size % (num_classes_in_batch * num_gpus) == 0, (
|
||||
'Batch size must be divisible by number of classes times the number of data parallel devices (if enabled).')
|
||||
|
||||
label_indices = {}
|
||||
for idx, item in enumerate(dataset_items):
|
||||
label = item['class_name']
|
||||
label = item[label_key]
|
||||
if label not in label_indices.keys():
|
||||
label_indices[label] = [idx]
|
||||
else:
|
||||
|
|
|
@ -8,6 +8,7 @@ from TTS.config.shared_configs import BaseDatasetConfig
|
|||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.utils.languages import get_language_balancer_weights
|
||||
from TTS.tts.utils.speakers import get_speaker_balancer_weights
|
||||
from TTS.encoder.utils.samplers import PerfectBatchSampler
|
||||
|
||||
# Fixing random state to avoid random fails
|
||||
torch.manual_seed(0)
|
||||
|
@ -82,3 +83,51 @@ class TestSamplers(unittest.TestCase):
|
|||
spk2 += 1
|
||||
|
||||
assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced"
|
||||
|
||||
def test_perfect_sampler(self): # pylint: disable=no-self-use
|
||||
classes = set()
|
||||
for item in train_samples:
|
||||
classes.add(item["speaker_name"])
|
||||
|
||||
sampler = PerfectBatchSampler(
|
||||
train_samples,
|
||||
classes,
|
||||
batch_size=2 * 3, # total batch size
|
||||
num_classes_in_batch=2,
|
||||
label_key="speaker_name",
|
||||
shuffle=False,
|
||||
drop_last=True)
|
||||
batchs = functools.reduce(lambda a, b: a + b, [list(sampler) for i in range(100)])
|
||||
for batch in batchs:
|
||||
spk1, spk2 = 0, 0
|
||||
# for in each batch
|
||||
for index in batch:
|
||||
if train_samples[index]["speaker_name"] == "ljspeech-0":
|
||||
spk1 += 1
|
||||
else:
|
||||
spk2 += 1
|
||||
assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced"
|
||||
|
||||
def test_perfect_sampler_shuffle(self): # pylint: disable=no-self-use
|
||||
classes = set()
|
||||
for item in train_samples:
|
||||
classes.add(item["speaker_name"])
|
||||
|
||||
sampler = PerfectBatchSampler(
|
||||
train_samples,
|
||||
classes,
|
||||
batch_size=2 * 3, # total batch size
|
||||
num_classes_in_batch=2,
|
||||
label_key="speaker_name",
|
||||
shuffle=True,
|
||||
drop_last=False)
|
||||
batchs = functools.reduce(lambda a, b: a + b, [list(sampler) for i in range(100)])
|
||||
for batch in batchs:
|
||||
spk1, spk2 = 0, 0
|
||||
# for in each batch
|
||||
for index in batch:
|
||||
if train_samples[index]["speaker_name"] == "ljspeech-0":
|
||||
spk1 += 1
|
||||
else:
|
||||
spk2 += 1
|
||||
assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced"
|
||||
|
|
Loading…
Reference in New Issue