mirror of https://github.com/coqui-ai/TTS.git
Add audio length sampler balancer (#1561)
* Add audio length sampler balancer * Add unit tests
This commit is contained in:
parent
6e460b7e42
commit
c6008e5235
|
@ -232,6 +232,14 @@ class BaseTTSConfig(BaseTrainingConfig):
|
|||
|
||||
language_weighted_sampler_alpha (float):
|
||||
Number that control the influence of the language sampler weights. Defaults to ```1.0```.
|
||||
|
||||
use_length_weighted_sampler (bool):
|
||||
Enable / Disable the batch balancer by audio length. If enabled the dataset will be divided
|
||||
into 10 buckets considering the min and max audio of the dataset. The sampler weights will be
|
||||
computed forcing to have the same quantity of data for each bucket in each training batch. Defaults to ```False```.
|
||||
|
||||
length_weighted_sampler_alpha (float):
|
||||
Number that control the influence of the length sampler weights. Defaults to ```1.0```.
|
||||
"""
|
||||
|
||||
audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
|
||||
|
@ -279,3 +287,5 @@ class BaseTTSConfig(BaseTrainingConfig):
|
|||
speaker_weighted_sampler_alpha: float = 1.0
|
||||
use_language_weighted_sampler: bool = False
|
||||
language_weighted_sampler_alpha: float = 1.0
|
||||
use_length_weighted_sampler: bool = False
|
||||
length_weighted_sampler_alpha: float = 1.0
|
||||
|
|
|
@ -12,6 +12,7 @@ from trainer.torch import DistributedSampler, DistributedSamplerWrapper
|
|||
|
||||
from TTS.model import BaseTrainerModel
|
||||
from TTS.tts.datasets.dataset import TTSDataset
|
||||
from TTS.tts.utils.data import get_length_balancer_weights
|
||||
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
|
||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
|
@ -250,6 +251,14 @@ class BaseTTS(BaseTrainerModel):
|
|||
else:
|
||||
weights = get_speaker_balancer_weights(data_items) * alpha
|
||||
|
||||
if getattr(config, "use_length_weighted_sampler", False):
|
||||
alpha = getattr(config, "length_weighted_sampler_alpha", 1.0)
|
||||
print(" > Using Length weighted sampler with alpha:", alpha)
|
||||
if weights is not None:
|
||||
weights += get_length_balancer_weights(data_items) * alpha
|
||||
else:
|
||||
weights = get_length_balancer_weights(data_items) * alpha
|
||||
|
||||
if weights is not None:
|
||||
sampler = WeightedRandomSampler(weights, len(weights))
|
||||
else:
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
import bisect
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def _pad_data(x, length):
|
||||
|
@ -51,3 +54,26 @@ def prepare_stop_target(inputs, out_steps):
|
|||
|
||||
def pad_per_step(inputs, pad_len):
|
||||
return np.pad(inputs, [[0, 0], [0, 0], [0, pad_len]], mode="constant", constant_values=0.0)
|
||||
|
||||
|
||||
def get_length_balancer_weights(items: list, num_buckets=10):
|
||||
# get all durations
|
||||
audio_lengths = np.array([item["audio_length"] for item in items])
|
||||
# create the $num_buckets buckets classes based in the dataset max and min length
|
||||
max_length = int(max(audio_lengths))
|
||||
min_length = int(min(audio_lengths))
|
||||
step = int((max_length - min_length) / num_buckets) + 1
|
||||
buckets_classes = [i + step for i in range(min_length, (max_length - step) + num_buckets + 1, step)]
|
||||
# add each sample in their respective length bucket
|
||||
buckets_names = np.array(
|
||||
[buckets_classes[bisect.bisect_left(buckets_classes, item["audio_length"])] for item in items]
|
||||
)
|
||||
# count and compute the weights_bucket for each sample
|
||||
unique_buckets_names = np.unique(buckets_names).tolist()
|
||||
bucket_ids = [unique_buckets_names.index(l) for l in buckets_names]
|
||||
bucket_count = np.array([len(np.where(buckets_names == l)[0]) for l in unique_buckets_names])
|
||||
weight_bucket = 1.0 / bucket_count
|
||||
dataset_samples_weight = np.array([weight_bucket[l] for l in bucket_ids])
|
||||
# normalize
|
||||
dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
|
||||
return torch.from_numpy(dataset_samples_weight).float()
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import functools
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
@ -6,6 +7,7 @@ import torch
|
|||
from TTS.config.shared_configs import BaseDatasetConfig
|
||||
from TTS.encoder.utils.samplers import PerfectBatchSampler
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.utils.data import get_length_balancer_weights
|
||||
from TTS.tts.utils.languages import get_language_balancer_weights
|
||||
from TTS.tts.utils.speakers import get_speaker_balancer_weights
|
||||
|
||||
|
@ -136,3 +138,28 @@ class TestSamplers(unittest.TestCase):
|
|||
else:
|
||||
spk2 += 1
|
||||
assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced"
|
||||
|
||||
def test_length_weighted_random_sampler(self): # pylint: disable=no-self-use
|
||||
for _ in range(1000):
|
||||
# gerenate a lenght unbalanced dataset with random max/min audio lenght
|
||||
min_audio = random.randrange(1, 22050)
|
||||
max_audio = random.randrange(44100, 220500)
|
||||
for idx, item in enumerate(train_samples):
|
||||
# increase the diversity of durations
|
||||
random_increase = random.randrange(100, 1000)
|
||||
if idx < 5:
|
||||
item["audio_length"] = min_audio + random_increase
|
||||
else:
|
||||
item["audio_length"] = max_audio + random_increase
|
||||
|
||||
weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(
|
||||
get_length_balancer_weights(train_samples, num_buckets=2), len(train_samples)
|
||||
)
|
||||
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
|
||||
len1, len2 = 0, 0
|
||||
for index in ids:
|
||||
if train_samples[index]["audio_length"] < max_audio:
|
||||
len1 += 1
|
||||
else:
|
||||
len2 += 1
|
||||
assert is_balanced(len1, len2), "Length Weighted sampler is supposed to be balanced"
|
||||
|
|
Loading…
Reference in New Issue