mirror of https://github.com/coqui-ai/TTS.git
refactor: move shared function into dataset.py
This commit is contained in:
parent
54f4228a46
commit
b1ac884e07
|
@ -63,6 +63,31 @@ def get_audio_size(audiopath: Union[str, os.PathLike[Any]]) -> int:
|
|||
raise RuntimeError(msg) from e
|
||||
|
||||
|
||||
def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: Optional[dict] = None):
    """Create inverse frequency weights for balancing the dataset.

    Use `multi_dict` to scale relative weights.

    Intended for a weighted sampler: samples whose attribute value (e.g.
    speaker or language name) is rare receive proportionally larger weights.

    Args:
        items: List of sample dicts; each must contain the key ``attr_name``.
        attr_name: Key whose value distribution is balanced over.
        multi_dict: Optional mapping from attribute value to an extra
            multiplier applied on top of the inverse-frequency weight.
            Every key must be an attribute value present in ``items``.

    Returns:
        Tuple of (per-sample weight ``torch.FloatTensor``, sorted list of
        unique attribute values, sorted list of unique weight values).
    """
    attr_names_samples = np.array([item[attr_name] for item in items])
    # One pass gives the sorted unique values, each sample's index into them,
    # and the per-value counts — replacing the per-sample list.index() and
    # per-value np.where() scans (O(n*u) -> O(n log n)).
    unique_values, attr_idx, attr_count = np.unique(attr_names_samples, return_inverse=True, return_counts=True)
    unique_attr_names = unique_values.tolist()
    weight_attr = 1.0 / attr_count
    dataset_samples_weight = weight_attr[attr_idx]
    dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
    if multi_dict is not None:
        # check if all keys are in the multi_dict
        # NOTE: kept as `assert` (raises AssertionError) for backward
        # compatibility; it is skipped when Python runs with -O.
        for k in multi_dict:
            assert k in unique_attr_names, f"{k} not in {unique_attr_names}"
        # scale weights
        multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items])
        dataset_samples_weight *= multiplier_samples
    return (
        torch.from_numpy(dataset_samples_weight).float(),
        unique_attr_names,
        np.unique(dataset_samples_weight).tolist(),
    )
|
||||
|
||||
|
||||
class TTSDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -17,7 +17,7 @@ from trainer.io import load_fsspec
|
|||
from trainer.torch import DistributedSampler, DistributedSamplerWrapper
|
||||
from trainer.trainer_utils import get_optimizer, get_scheduler
|
||||
|
||||
from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample
|
||||
from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample, get_attribute_balancer_weights
|
||||
from TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel
|
||||
from TTS.tts.layers.losses import (
|
||||
ForwardSumLoss,
|
||||
|
@ -194,25 +194,6 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
|
|||
##############################
|
||||
|
||||
|
||||
def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None):
    """Build inverse-frequency per-sample weights for a torch WeightedSampler.

    Rare values of ``attr_name`` get larger weights; ``multi_dict`` optionally
    rescales the weight of specific attribute values (default multiplier 1.0).
    Returns (weight tensor, sorted unique attribute values, sorted unique weights).
    """
    values = np.array([sample[attr_name] for sample in items])
    # Sorted unique values, each sample's index into them, and per-value counts.
    names, inverse, counts = np.unique(values, return_inverse=True, return_counts=True)
    unique_attr_names = names.tolist()
    per_value_weight = 1.0 / counts
    sample_weights = per_value_weight[inverse]
    sample_weights = sample_weights / np.linalg.norm(sample_weights)
    if multi_dict is not None:
        # Scale each sample by its attribute's multiplier; unknown values keep 1.0.
        scales = np.array([multi_dict.get(sample[attr_name], 1.0) for sample in items])
        sample_weights = sample_weights * scales
    return (
        torch.from_numpy(sample_weights).float(),
        unique_attr_names,
        np.unique(sample_weights).tolist(),
    )
|
||||
|
||||
|
||||
class ForwardTTSE2eF0Dataset(F0Dataset):
|
||||
"""Override F0Dataset to avoid slow computing of pitches"""
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ from trainer.torch import DistributedSampler, DistributedSamplerWrapper
|
|||
from trainer.trainer_utils import get_optimizer, get_scheduler
|
||||
|
||||
from TTS.tts.configs.shared_configs import CharactersConfig
|
||||
from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
|
||||
from TTS.tts.datasets.dataset import TTSDataset, _parse_sample, get_attribute_balancer_weights
|
||||
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
||||
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
|
||||
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
|
||||
|
@ -219,30 +219,6 @@ class VitsAudioConfig(Coqpit):
|
|||
##############################
|
||||
|
||||
|
||||
def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None):
    """Create inverse frequency weights for balancing the dataset.

    Use `multi_dict` to scale relative weights."""
    attr_values = [sample[attr_name] for sample in items]
    samples_arr = np.array(attr_values)
    unique_attr_names = np.unique(samples_arr).tolist()
    # Map each attribute value to its position in the sorted unique list.
    position = {name: idx for idx, name in enumerate(unique_attr_names)}
    idx_per_sample = np.array([position[value] for value in attr_values])
    # Occurrences of every unique value via boolean masks.
    counts = np.array([(samples_arr == name).sum() for name in unique_attr_names])
    inv_freq = 1.0 / counts
    dataset_samples_weight = inv_freq[idx_per_sample]
    dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
    if multi_dict is not None:
        # check if all keys are in the multi_dict
        for k in multi_dict:
            assert k in unique_attr_names, f"{k} not in {unique_attr_names}"
        # scale weights
        multipliers = np.array([multi_dict.get(sample[attr_name], 1.0) for sample in items])
        dataset_samples_weight *= multipliers
    return (
        torch.from_numpy(dataset_samples_weight).float(),
        unique_attr_names,
        np.unique(dataset_samples_weight).tolist(),
    )
|
||||
|
||||
|
||||
class VitsDataset(TTSDataset):
|
||||
def __init__(self, model_args, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue