Add audio length sampler balancer

This commit is contained in:
Edresson Casanova 2022-05-07 14:42:58 -03:00
parent 27cf388a79
commit e86e3d2e87
3 changed files with 45 additions and 0 deletions

View File

@ -232,6 +232,14 @@ class BaseTTSConfig(BaseTrainingConfig):
language_weighted_sampler_alpha (float):
Number that control the influence of the language sampler weights. Defaults to ```1.0```.
use_length_weighted_sampler (bool):
Enable / Disable the batch balancer by audio length. If enabled the dataset will be divided
into 10 buckets considering the min and max audio of the dataset. The sampler weights will be
computed forcing to have the same quantity of data for each bucket in each training batch. Defaults to ```False```.
length_weighted_sampler_alpha (float):
Number that control the influence of the length sampler weights. Defaults to ```1.0```.
"""
audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
@ -279,3 +287,5 @@ class BaseTTSConfig(BaseTrainingConfig):
speaker_weighted_sampler_alpha: float = 1.0
use_language_weighted_sampler: bool = False
language_weighted_sampler_alpha: float = 1.0
use_length_weighted_sampler: bool = False
length_weighted_sampler_alpha: float = 1.0

View File

@ -12,6 +12,7 @@ from trainer.torch import DistributedSampler, DistributedSamplerWrapper
from TTS.model import BaseTrainerModel
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.data import get_length_balancer_weights
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
@ -250,6 +251,14 @@ class BaseTTS(BaseTrainerModel):
else:
weights = get_speaker_balancer_weights(data_items) * alpha
if getattr(config, "use_length_weighted_sampler", False):
alpha = getattr(config, "length_weighted_sampler_alpha", 1.0)
print(" > Using Length weighted sampler with alpha:", alpha)
if weights is not None:
weights += get_length_balancer_weights(data_items) * alpha
else:
weights = get_length_balancer_weights(data_items) * alpha
if weights is not None:
sampler = WeightedRandomSampler(weights, len(weights))
else:

View File

@ -1,4 +1,7 @@
import bisect
import numpy as np
import torch
def _pad_data(x, length):
@ -51,3 +54,26 @@ def prepare_stop_target(inputs, out_steps):
def pad_per_step(inputs, pad_len):
return np.pad(inputs, [[0, 0], [0, 0], [0, pad_len]], mode="constant", constant_values=0.0)
def get_length_balancer_weights(items: list, num_buckets=10):
# get all durations
audio_lengths = np.array([item["audio_length"] for item in items])
# create the $num_buckets buckets classes based in the dataset max and min length
max_length = int(max(audio_lengths))
min_length = int(min(audio_lengths))
step = int((max_length - min_length) / num_buckets)
buckets_classes = [i + step for i in range(min_length, max_length, step)]
# add each sample in their respective length bucket
buckets_names = np.array(
[buckets_classes[bisect.bisect_left(buckets_classes, item["audio_length"])] for item in items]
)
# count and compute the weights_bucket for each sample
unique_buckets_names = np.unique(buckets_names).tolist()
bucket_ids = [unique_buckets_names.index(l) for l in buckets_names]
bucket_count = np.array([len(np.where(buckets_names == l)[0]) for l in unique_buckets_names])
weight_bucket = 1.0 / bucket_count
dataset_samples_weight = np.array([weight_bucket[l] for l in bucket_ids])
# normalize
dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
return torch.from_numpy(dataset_samples_weight).float()