From c6008e52356b46f556e70959d07bb8fb2af8084f Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 12 May 2022 14:59:19 -0300 Subject: [PATCH] Add audio length sampler balancer (#1561) * Add audio length sampler balancer * Add unit tests --- TTS/tts/configs/shared_configs.py | 10 ++++++++++ TTS/tts/models/base_tts.py | 9 +++++++++ TTS/tts/utils/data.py | 26 ++++++++++++++++++++++++++ tests/data_tests/test_samplers.py | 27 +++++++++++++++++++++++++++ 4 files changed, 72 insertions(+) diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py index dcc862e8..b782117c 100644 --- a/TTS/tts/configs/shared_configs.py +++ b/TTS/tts/configs/shared_configs.py @@ -232,6 +232,14 @@ class BaseTTSConfig(BaseTrainingConfig): language_weighted_sampler_alpha (float): Number that control the influence of the language sampler weights. Defaults to ```1.0```. + + use_length_weighted_sampler (bool): + Enable / Disable the batch balancer by audio length. If enabled the dataset will be divided + into 10 buckets considering the min and max audio of the dataset. The sampler weights will be + computed forcing to have the same quantity of data for each bucket in each training batch. Defaults to ```False```. + + length_weighted_sampler_alpha (float): + Number that control the influence of the length sampler weights. Defaults to ```1.0```. 
""" audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) @@ -279,3 +287,5 @@ class BaseTTSConfig(BaseTrainingConfig): speaker_weighted_sampler_alpha: float = 1.0 use_language_weighted_sampler: bool = False language_weighted_sampler_alpha: float = 1.0 + use_length_weighted_sampler: bool = False + length_weighted_sampler_alpha: float = 1.0 diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 652b77dd..c71872d3 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -12,6 +12,7 @@ from trainer.torch import DistributedSampler, DistributedSamplerWrapper from TTS.model import BaseTrainerModel from TTS.tts.datasets.dataset import TTSDataset +from TTS.tts.utils.data import get_length_balancer_weights from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager from TTS.tts.utils.synthesis import synthesis @@ -250,6 +251,14 @@ class BaseTTS(BaseTrainerModel): else: weights = get_speaker_balancer_weights(data_items) * alpha + if getattr(config, "use_length_weighted_sampler", False): + alpha = getattr(config, "length_weighted_sampler_alpha", 1.0) + print(" > Using Length weighted sampler with alpha:", alpha) + if weights is not None: + weights += get_length_balancer_weights(data_items) * alpha + else: + weights = get_length_balancer_weights(data_items) * alpha + if weights is not None: sampler = WeightedRandomSampler(weights, len(weights)) else: diff --git a/TTS/tts/utils/data.py b/TTS/tts/utils/data.py index b0d88740..22e46b68 100644 --- a/TTS/tts/utils/data.py +++ b/TTS/tts/utils/data.py @@ -1,4 +1,7 @@ +import bisect + import numpy as np +import torch def _pad_data(x, length): @@ -51,3 +54,26 @@ def prepare_stop_target(inputs, out_steps): def pad_per_step(inputs, pad_len): return np.pad(inputs, [[0, 0], [0, 0], [0, pad_len]], mode="constant", constant_values=0.0) + + +def 
get_length_balancer_weights(items: list, num_buckets=10): + # get all durations + audio_lengths = np.array([item["audio_length"] for item in items]) + # create the $num_buckets bucket classes based on the dataset max and min length + max_length = int(max(audio_lengths)) + min_length = int(min(audio_lengths)) + step = int((max_length - min_length) / num_buckets) + 1 + buckets_classes = [i + step for i in range(min_length, (max_length - step) + num_buckets + 1, step)] + # assign each sample to its respective length bucket + buckets_names = np.array( + [buckets_classes[bisect.bisect_left(buckets_classes, item["audio_length"])] for item in items] + ) + # count and compute the weights_bucket for each sample + unique_buckets_names = np.unique(buckets_names).tolist() + bucket_ids = [unique_buckets_names.index(l) for l in buckets_names] + bucket_count = np.array([len(np.where(buckets_names == l)[0]) for l in unique_buckets_names]) + weight_bucket = 1.0 / bucket_count + dataset_samples_weight = np.array([weight_bucket[l] for l in bucket_ids]) + # normalize + dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) + return torch.from_numpy(dataset_samples_weight).float() diff --git a/tests/data_tests/test_samplers.py b/tests/data_tests/test_samplers.py index 42f1bfd5..b85e0ec4 100644 --- a/tests/data_tests/test_samplers.py +++ b/tests/data_tests/test_samplers.py @@ -1,4 +1,5 @@ import functools +import random import unittest import torch @@ -6,6 +7,7 @@ import torch from TTS.config.shared_configs import BaseDatasetConfig from TTS.encoder.utils.samplers import PerfectBatchSampler from TTS.tts.datasets import load_tts_samples +from TTS.tts.utils.data import get_length_balancer_weights from TTS.tts.utils.languages import get_language_balancer_weights from TTS.tts.utils.speakers import get_speaker_balancer_weights @@ -136,3 +138,28 @@ class TestSamplers(unittest.TestCase): else: spk2 += 1 assert spk1 == spk2, "PerfectBatchSampler is supposed to 
be perfectly balanced" + + def test_length_weighted_random_sampler(self): # pylint: disable=no-self-use + for _ in range(1000): + # generate a length-unbalanced dataset with random max/min audio length + min_audio = random.randrange(1, 22050) + max_audio = random.randrange(44100, 220500) + for idx, item in enumerate(train_samples): + # increase the diversity of durations + random_increase = random.randrange(100, 1000) + if idx < 5: + item["audio_length"] = min_audio + random_increase + else: + item["audio_length"] = max_audio + random_increase + + weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler( + get_length_balancer_weights(train_samples, num_buckets=2), len(train_samples) + ) + ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)]) + len1, len2 = 0, 0 + for index in ids: + if train_samples[index]["audio_length"] < max_audio: + len1 += 1 + else: + len2 += 1 + assert is_balanced(len1, len2), "Length Weighted sampler is supposed to be balanced"