Add audio length sampler balancer (#1561)

* Add audio length sampler balancer

* Add unit tests
This commit is contained in:
Edresson Casanova 2022-05-12 14:59:19 -03:00 committed by GitHub
parent 6e460b7e42
commit c6008e5235
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 72 additions and 0 deletions

View File

@ -232,6 +232,14 @@ class BaseTTSConfig(BaseTrainingConfig):
language_weighted_sampler_alpha (float):
Number that controls the influence of the language sampler weights. Defaults to ```1.0```.
use_length_weighted_sampler (bool):
Enable / Disable the batch balancer by audio length. If enabled, the dataset is divided
into 10 buckets spanning the min and max audio length of the dataset. The sampler weights are
computed so that every bucket is expected to contribute the same amount of data to each training batch. Defaults to ```False```.
length_weighted_sampler_alpha (float):
Number that controls the influence of the length sampler weights. Defaults to ```1.0```.
"""
audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
@ -279,3 +287,5 @@ class BaseTTSConfig(BaseTrainingConfig):
speaker_weighted_sampler_alpha: float = 1.0
use_language_weighted_sampler: bool = False
language_weighted_sampler_alpha: float = 1.0
use_length_weighted_sampler: bool = False
length_weighted_sampler_alpha: float = 1.0

View File

@ -12,6 +12,7 @@ from trainer.torch import DistributedSampler, DistributedSamplerWrapper
from TTS.model import BaseTrainerModel
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.data import get_length_balancer_weights
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
@ -250,6 +251,14 @@ class BaseTTS(BaseTrainerModel):
else:
weights = get_speaker_balancer_weights(data_items) * alpha
if getattr(config, "use_length_weighted_sampler", False):
alpha = getattr(config, "length_weighted_sampler_alpha", 1.0)
print(" > Using Length weighted sampler with alpha:", alpha)
if weights is not None:
weights += get_length_balancer_weights(data_items) * alpha
else:
weights = get_length_balancer_weights(data_items) * alpha
if weights is not None:
sampler = WeightedRandomSampler(weights, len(weights))
else:

View File

@ -1,4 +1,7 @@
import bisect
import numpy as np
import torch
def _pad_data(x, length):
@ -51,3 +54,26 @@ def prepare_stop_target(inputs, out_steps):
def pad_per_step(inputs, pad_len):
    """Zero-pad the end of the last axis of a 3-D array.

    Args:
        inputs: Array-like with three dimensions (one pad-width pair is
            supplied per axis).
        pad_len (int): Number of zeros appended after the last axis; the
            first two axes are left untouched.

    Returns:
        np.ndarray: Copy of ``inputs`` whose last axis is ``pad_len`` longer.
    """
    # One [before, after] pair per axis: only the tail of the last axis grows.
    pad_widths = [[0, 0], [0, 0], [0, pad_len]]
    return np.pad(inputs, pad_widths, mode="constant", constant_values=0.0)
def get_length_balancer_weights(items: list, num_buckets=10):
    """Compute per-sample sampling weights that balance a dataset by audio length.

    The span between the shortest and the longest ``audio_length`` is cut into
    equal-width buckets. Every sample is weighted by the inverse of the number
    of samples sharing its bucket, so each non-empty bucket carries the same
    total weight.

    Args:
        items (list): Samples; each one must expose a numeric value under the
            ``"audio_length"`` key.
        num_buckets (int): Number of equal-width length buckets used to group
            the samples. Must be at least 1. Defaults to 10.

    Returns:
        torch.Tensor: 1-D float32 tensor holding one weight per item, in the
        same order as ``items``, scaled to unit L2 norm.

    Raises:
        ValueError: If ``items`` is empty or ``num_buckets`` is smaller than 1.
    """
    if num_buckets < 1:
        raise ValueError(f"num_buckets must be a positive integer, got {num_buckets}")
    if not items:
        raise ValueError("Cannot compute length balancer weights for an empty list of items")

    # get all durations
    audio_lengths = np.array([item["audio_length"] for item in items])

    # create the bucket upper bounds based on the dataset max and min length
    max_length = int(audio_lengths.max())
    min_length = int(audio_lengths.min())
    # the "+ 1" guarantees the last upper bound lies above max_length, so every
    # sample falls inside a bucket
    step = int((max_length - min_length) / num_buckets) + 1
    buckets_classes = [i + step for i in range(min_length, (max_length - step) + num_buckets + 1, step)]

    # assign each sample to the first bucket whose upper bound is >= its length
    sample_buckets = np.array([bisect.bisect_left(buckets_classes, length) for length in audio_lengths])

    # count the samples per (non-empty) bucket and weight each sample by the
    # inverse of its bucket population
    _, bucket_ids, bucket_count = np.unique(sample_buckets, return_inverse=True, return_counts=True)
    dataset_samples_weight = 1.0 / bucket_count[bucket_ids]

    # normalize
    dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
    return torch.from_numpy(dataset_samples_weight).float()

View File

@ -1,4 +1,5 @@
import functools
import random
import unittest
import torch
@ -6,6 +7,7 @@ import torch
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.encoder.utils.samplers import PerfectBatchSampler
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.data import get_length_balancer_weights
from TTS.tts.utils.languages import get_language_balancer_weights
from TTS.tts.utils.speakers import get_speaker_balancer_weights
@ -136,3 +138,28 @@ class TestSamplers(unittest.TestCase):
else:
spk2 += 1
assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced"
def test_length_weighted_random_sampler(self):  # pylint: disable=no-self-use
    """Check that length-weighted sampling draws short and long clips evenly.

    On every round the first five samples get a short ``audio_length`` and the
    remaining ones a long one; the sampler built from the length balancer
    weights (two buckets) is then drawn from 100 times and the short/long
    counts are compared with ``is_balanced``.
    """
    for _ in range(1000):
        # generate a length-unbalanced dataset with random max/min audio length
        min_audio = random.randrange(1, 22050)
        max_audio = random.randrange(44100, 220500)
        for position, sample in enumerate(train_samples):
            # increase the diversity of durations
            jitter = random.randrange(100, 1000)
            base_length = min_audio if position < 5 else max_audio
            sample["audio_length"] = base_length + jitter

        sample_weights = get_length_balancer_weights(train_samples, num_buckets=2)
        weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(sample_weights, len(train_samples))

        # draw 100 full passes from the sampler and pool the sampled indices
        drawn = [index for _pass in range(100) for index in weighted_sampler]

        short_count = sum(1 for index in drawn if train_samples[index]["audio_length"] < max_audio)
        long_count = len(drawn) - short_count
        assert is_balanced(short_count, long_count), "Length Weighted sampler is supposed to be balanced"