mirror of https://github.com/coqui-ai/TTS.git
Add audio length sampler balancer
This commit is contained in:
parent 27cf388a79
commit e86e3d2e87
@@ -232,6 +232,14 @@ class BaseTTSConfig(BaseTrainingConfig):
         language_weighted_sampler_alpha (float):
             Number that controls the influence of the language sampler weights. Defaults to ```1.0```.

+        use_length_weighted_sampler (bool):
+            Enable / Disable the batch balancer by audio length. If enabled, the dataset is divided
+            into 10 buckets spanning the minimum and maximum audio length in the dataset. The sampler weights are
+            computed so that every bucket contributes the same amount of data to each training batch. Defaults to ```False```.
+
+        length_weighted_sampler_alpha (float):
+            Number that controls the influence of the length sampler weights. Defaults to ```1.0```.
+
     """

     audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
@@ -279,3 +287,5 @@ class BaseTTSConfig(BaseTrainingConfig):
     speaker_weighted_sampler_alpha: float = 1.0
     use_language_weighted_sampler: bool = False
     language_weighted_sampler_alpha: float = 1.0
+    use_length_weighted_sampler: bool = False
+    length_weighted_sampler_alpha: float = 1.0
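Assuming a config class that inherits from BaseTTSConfig, the two new fields can be switched on directly when the config is built. The sketch below uses VitsConfig purely as an illustration; any BaseTTSConfig subclass should expose the same fields:

# Hedged usage sketch: enable the length-weighted batch balancer on a
# BaseTTSConfig subclass (VitsConfig is only an example here).
from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig(
    use_length_weighted_sampler=True,    # turn on the audio-length batch balancer
    length_weighted_sampler_alpha=1.0,   # default influence of the length weights
)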
@@ -12,6 +12,7 @@ from trainer.torch import DistributedSampler, DistributedSamplerWrapper

 from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
+from TTS.tts.utils.data import get_length_balancer_weights
 from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
 from TTS.tts.utils.synthesis import synthesis
@@ -250,6 +251,14 @@ class BaseTTS(BaseTrainerModel):
             else:
                 weights = get_speaker_balancer_weights(data_items) * alpha

+        if getattr(config, "use_length_weighted_sampler", False):
+            alpha = getattr(config, "length_weighted_sampler_alpha", 1.0)
+            print(" > Using Length weighted sampler with alpha:", alpha)
+            if weights is not None:
+                weights += get_length_balancer_weights(data_items) * alpha
+            else:
+                weights = get_length_balancer_weights(data_items) * alpha
+
         if weights is not None:
             sampler = WeightedRandomSampler(weights, len(weights))
         else:
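For context, the additive weighting pattern above can be reproduced outside BaseTTS. The sketch below is only an illustration under a few assumptions: the toy data_items carry just the keys the two balancers read (speaker_name and audio_length), each enabled balancer adds its weights scaled by its alpha, and the summed weights drive a WeightedRandomSampler as in the hunk above:

# Illustrative sketch of the additive weighting pattern (not BaseTTS itself).
from torch.utils.data import WeightedRandomSampler

from TTS.tts.utils.data import get_length_balancer_weights
from TTS.tts.utils.speakers import get_speaker_balancer_weights

# Toy samples; real items come from the TTS dataset loader.
data_items = [
    {"speaker_name": "spk_a", "audio_length": 2.0},
    {"speaker_name": "spk_a", "audio_length": 3.0},
    {"speaker_name": "spk_a", "audio_length": 11.0},
    {"speaker_name": "spk_b", "audio_length": 12.0},
]

# Each enabled balancer contributes additively, scaled by its alpha.
weights = get_speaker_balancer_weights(data_items) * 1.0
weights += get_length_balancer_weights(data_items) * 1.0

# One draw index per sample, with replacement; pass it to a DataLoader as sampler=.
sampler = WeightedRandomSampler(weights, len(weights))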
@@ -1,4 +1,7 @@
+import bisect
+
 import numpy as np
+import torch


 def _pad_data(x, length):
@@ -51,3 +54,26 @@ def prepare_stop_target(inputs, out_steps):

 def pad_per_step(inputs, pad_len):
     return np.pad(inputs, [[0, 0], [0, 0], [0, pad_len]], mode="constant", constant_values=0.0)
+
+
+def get_length_balancer_weights(items: list, num_buckets=10):
+    # get all durations
+    audio_lengths = np.array([item["audio_length"] for item in items])
+    # create the $num_buckets bucket classes based on the dataset max and min length
+    max_length = int(max(audio_lengths))
+    min_length = int(min(audio_lengths))
+    step = int((max_length - min_length) / num_buckets)
+    buckets_classes = [i + step for i in range(min_length, max_length, step)]
+    # add each sample to its respective length bucket
+    buckets_names = np.array(
+        [buckets_classes[bisect.bisect_left(buckets_classes, item["audio_length"])] for item in items]
+    )
+    # count the samples per bucket and compute the per-sample bucket weights
+    unique_buckets_names = np.unique(buckets_names).tolist()
+    bucket_ids = [unique_buckets_names.index(l) for l in buckets_names]
+    bucket_count = np.array([len(np.where(buckets_names == l)[0]) for l in unique_buckets_names])
+    weight_bucket = 1.0 / bucket_count
+    dataset_samples_weight = np.array([weight_bucket[l] for l in bucket_ids])
+    # normalize
+    dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
+    return torch.from_numpy(dataset_samples_weight).float()
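A quick way to see what get_length_balancer_weights does is to feed it a toy list in which one length bucket is rare. With nine short clips and one long clip (audio_length assumed to be in seconds), the lone long clip falls into an under-populated bucket and therefore receives a larger sampling weight:

# Toy check of the bucketing behaviour (hypothetical data, lengths in seconds).
from TTS.tts.utils.data import get_length_balancer_weights

items = [{"audio_length": 2.0 + 0.1 * i} for i in range(9)] + [{"audio_length": 20.0}]
weights = get_length_balancer_weights(items, num_buckets=10)

print(weights)                    # nine small weights followed by one larger weight
assert weights[-1] > weights[0]   # the rare long clip is up-weighted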