refactor: move shared function into dataset.py

This commit is contained in:
Enno Hermann 2024-11-22 22:33:25 +01:00
parent 54f4228a46
commit b1ac884e07
3 changed files with 27 additions and 45 deletions

View File

@ -63,6 +63,31 @@ def get_audio_size(audiopath: Union[str, os.PathLike[Any]]) -> int:
raise RuntimeError(msg) from e
def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: Optional[dict] = None):
"""Create inverse frequency weights for balancing the dataset.
Use `multi_dict` to scale relative weights."""
attr_names_samples = np.array([item[attr_name] for item in items])
unique_attr_names = np.unique(attr_names_samples).tolist()
attr_idx = [unique_attr_names.index(l) for l in attr_names_samples]
attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names])
weight_attr = 1.0 / attr_count
dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx])
dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
if multi_dict is not None:
# check if all keys are in the multi_dict
for k in multi_dict:
assert k in unique_attr_names, f"{k} not in {unique_attr_names}"
# scale weights
multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items])
dataset_samples_weight *= multiplier_samples
return (
torch.from_numpy(dataset_samples_weight).float(),
unique_attr_names,
np.unique(dataset_samples_weight).tolist(),
)
class TTSDataset(Dataset):
def __init__(
self,

View File

@ -17,7 +17,7 @@ from trainer.io import load_fsspec
from trainer.torch import DistributedSampler, DistributedSamplerWrapper
from trainer.trainer_utils import get_optimizer, get_scheduler
from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample
from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample, get_attribute_balancer_weights
from TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel
from TTS.tts.layers.losses import (
ForwardSumLoss,
@ -194,25 +194,6 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
##############################
def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None):
"""Create balancer weight for torch WeightedSampler"""
attr_names_samples = np.array([item[attr_name] for item in items])
unique_attr_names = np.unique(attr_names_samples).tolist()
attr_idx = [unique_attr_names.index(l) for l in attr_names_samples]
attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names])
weight_attr = 1.0 / attr_count
dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx])
dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
if multi_dict is not None:
multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items])
dataset_samples_weight *= multiplier_samples
return (
torch.from_numpy(dataset_samples_weight).float(),
unique_attr_names,
np.unique(dataset_samples_weight).tolist(),
)
class ForwardTTSE2eF0Dataset(F0Dataset):
"""Override F0Dataset to avoid slow computing of pitches"""

View File

@ -21,7 +21,7 @@ from trainer.torch import DistributedSampler, DistributedSamplerWrapper
from trainer.trainer_utils import get_optimizer, get_scheduler
from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
from TTS.tts.datasets.dataset import TTSDataset, _parse_sample, get_attribute_balancer_weights
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
@ -219,30 +219,6 @@ class VitsAudioConfig(Coqpit):
##############################
def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None):
"""Create inverse frequency weights for balancing the dataset.
Use `multi_dict` to scale relative weights."""
attr_names_samples = np.array([item[attr_name] for item in items])
unique_attr_names = np.unique(attr_names_samples).tolist()
attr_idx = [unique_attr_names.index(l) for l in attr_names_samples]
attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names])
weight_attr = 1.0 / attr_count
dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx])
dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
if multi_dict is not None:
# check if all keys are in the multi_dict
for k in multi_dict:
assert k in unique_attr_names, f"{k} not in {unique_attr_names}"
# scale weights
multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items])
dataset_samples_weight *= multiplier_samples
return (
torch.from_numpy(dataset_samples_weight).float(),
unique_attr_names,
np.unique(dataset_samples_weight).tolist(),
)
class VitsDataset(TTSDataset):
def __init__(self, model_args, *args, **kwargs):
super().__init__(*args, **kwargs)