mirror of https://github.com/coqui-ai/TTS.git
Add Perfect Batch Sampler unit test
This commit is contained in:
parent
9c8b8201c3
commit
b0bad56ba9
|
@ -34,14 +34,14 @@ class PerfectBatchSampler(Sampler):
|
||||||
drop_last (bool): if True, drops last incomplete batch.
|
drop_last (bool): if True, drops last incomplete batch.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False):
|
def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False, label_key="class_name"):
|
||||||
super().__init__(dataset_items)
|
super().__init__(dataset_items)
|
||||||
assert batch_size % (num_classes_in_batch * num_gpus) == 0, (
|
assert batch_size % (num_classes_in_batch * num_gpus) == 0, (
|
||||||
'Batch size must be divisible by number of classes times the number of data parallel devices (if enabled).')
|
'Batch size must be divisible by number of classes times the number of data parallel devices (if enabled).')
|
||||||
|
|
||||||
label_indices = {}
|
label_indices = {}
|
||||||
for idx, item in enumerate(dataset_items):
|
for idx, item in enumerate(dataset_items):
|
||||||
label = item['class_name']
|
label = item[label_key]
|
||||||
if label not in label_indices.keys():
|
if label not in label_indices.keys():
|
||||||
label_indices[label] = [idx]
|
label_indices[label] = [idx]
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -8,6 +8,7 @@ from TTS.config.shared_configs import BaseDatasetConfig
|
||||||
from TTS.tts.datasets import load_tts_samples
|
from TTS.tts.datasets import load_tts_samples
|
||||||
from TTS.tts.utils.languages import get_language_balancer_weights
|
from TTS.tts.utils.languages import get_language_balancer_weights
|
||||||
from TTS.tts.utils.speakers import get_speaker_balancer_weights
|
from TTS.tts.utils.speakers import get_speaker_balancer_weights
|
||||||
|
from TTS.encoder.utils.samplers import PerfectBatchSampler
|
||||||
|
|
||||||
# Fixing random state to avoid random fails
|
# Fixing random state to avoid random fails
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
|
@ -82,3 +83,51 @@ class TestSamplers(unittest.TestCase):
|
||||||
spk2 += 1
|
spk2 += 1
|
||||||
|
|
||||||
assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced"
|
assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced"
|
||||||
|
|
||||||
|
def test_perfect_sampler(self):
    """Check that every batch from an unshuffled PerfectBatchSampler is balanced.

    The sampler is built over ``train_samples`` keyed on ``"speaker_name"``,
    with a total batch size of 6, 2 classes per batch, ``shuffle=False`` and
    ``drop_last=True``. It is iterated 100 times, and each yielded batch must
    hold as many ``"ljspeech-0"`` items as items from any other speaker.
    """
    classes = {item["speaker_name"] for item in train_samples}

    sampler = PerfectBatchSampler(
        train_samples,
        classes,
        batch_size=2 * 3,  # total batch size
        num_classes_in_batch=2,
        label_key="speaker_name",
        shuffle=False,
        drop_last=True,
    )

    # Pool the batches of 100 full passes over the sampler. A flat
    # comprehension avoids the quadratic cost of repeated list concatenation.
    batches = [batch for _ in range(100) for batch in sampler]

    for batch in batches:
        # Count the items of each speaker in this batch; everything that is
        # not "ljspeech-0" is attributed to the second speaker.
        is_first_speaker = [train_samples[index]["speaker_name"] == "ljspeech-0" for index in batch]
        spk1 = sum(is_first_speaker)
        spk2 = len(is_first_speaker) - spk1
        # assertEqual (unlike a bare assert) survives `python -O` and reports
        # both counts on failure.
        self.assertEqual(spk1, spk2, "PerfectBatchSampler is supposed to be perfectly balanced")
|
||||||
|
|
||||||
|
def test_perfect_sampler_shuffle(self):
    """Check that every batch from a shuffled PerfectBatchSampler is balanced.

    The sampler is built over ``train_samples`` keyed on ``"speaker_name"``,
    with a total batch size of 6, 2 classes per batch, ``shuffle=True`` and
    ``drop_last=False``. It is iterated 100 times, and each yielded batch must
    hold as many ``"ljspeech-0"`` items as items from any other speaker.
    """
    classes = {item["speaker_name"] for item in train_samples}

    sampler = PerfectBatchSampler(
        train_samples,
        classes,
        batch_size=2 * 3,  # total batch size
        num_classes_in_batch=2,
        label_key="speaker_name",
        shuffle=True,
        drop_last=False,
    )

    # Pool the batches of 100 full passes over the sampler. A flat
    # comprehension avoids the quadratic cost of repeated list concatenation.
    batches = [batch for _ in range(100) for batch in sampler]

    for batch in batches:
        # Count the items of each speaker in this batch; everything that is
        # not "ljspeech-0" is attributed to the second speaker.
        is_first_speaker = [train_samples[index]["speaker_name"] == "ljspeech-0" for index in batch]
        spk1 = sum(is_first_speaker)
        spk2 = len(is_first_speaker) - spk1
        # assertEqual (unlike a bare assert) survives `python -O` and reports
        # both counts on failure.
        self.assertEqual(spk1, spk2, "PerfectBatchSampler is supposed to be perfectly balanced")
|
||||||
|
|
Loading…
Reference in New Issue