mirror of https://github.com/coqui-ai/TTS.git
refactor: move shared function into dataset.py
This commit is contained in:
parent
54f4228a46
commit
b1ac884e07
|
@ -63,6 +63,31 @@ def get_audio_size(audiopath: Union[str, os.PathLike[Any]]) -> int:
|
||||||
raise RuntimeError(msg) from e
|
raise RuntimeError(msg) from e
|
||||||
|
|
||||||
|
|
||||||
|
def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: Optional[dict] = None):
    """Create inverse frequency weights for balancing the dataset.

    Use `multi_dict` to scale relative weights.

    Args:
        items: List of sample dicts; each dict must contain the key `attr_name`.
        attr_name: Name of the attribute to balance on (e.g. speaker or language).
        multi_dict: Optional mapping from attribute value to a weight multiplier.
            Every key must be a value that actually occurs in the dataset.

    Returns:
        Tuple of:
            - per-sample weight tensor (float32), L2-normalized before scaling,
            - sorted list of the unique attribute values,
            - sorted list of the unique weight values.
    """
    attr_names_samples = np.array([item[attr_name] for item in items])
    # One pass over the data: np.unique yields the sorted unique values, the
    # per-sample index into them, and their counts — replacing the original
    # O(n*m) `list.index` / `np.where` scans with vectorized lookups.
    unique_values, attr_idx, attr_count = np.unique(
        attr_names_samples, return_inverse=True, return_counts=True
    )
    unique_attr_names = unique_values.tolist()
    # Inverse-frequency weight: rare attribute values get larger weights.
    weight_attr = 1.0 / attr_count
    dataset_samples_weight = weight_attr[attr_idx]
    dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
    if multi_dict is not None:
        # Fail fast on typos: every scaling key must be a known attribute value.
        for k in multi_dict:
            assert k in unique_attr_names, f"{k} not in {unique_attr_names}"
        # scale weights
        multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items])
        dataset_samples_weight *= multiplier_samples
    return (
        torch.from_numpy(dataset_samples_weight).float(),
        unique_attr_names,
        np.unique(dataset_samples_weight).tolist(),
    )
|
||||||
|
|
||||||
|
|
||||||
class TTSDataset(Dataset):
|
class TTSDataset(Dataset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -17,7 +17,7 @@ from trainer.io import load_fsspec
|
||||||
from trainer.torch import DistributedSampler, DistributedSamplerWrapper
|
from trainer.torch import DistributedSampler, DistributedSamplerWrapper
|
||||||
from trainer.trainer_utils import get_optimizer, get_scheduler
|
from trainer.trainer_utils import get_optimizer, get_scheduler
|
||||||
|
|
||||||
from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample
|
from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample, get_attribute_balancer_weights
|
||||||
from TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel
|
from TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel
|
||||||
from TTS.tts.layers.losses import (
|
from TTS.tts.layers.losses import (
|
||||||
ForwardSumLoss,
|
ForwardSumLoss,
|
||||||
|
@ -194,25 +194,6 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
|
||||||
##############################
|
##############################
|
||||||
|
|
||||||
|
|
||||||
def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None):
    """Create balancer weight for torch WeightedSampler.

    Weights are the inverse frequency of each attribute value, L2-normalized,
    optionally scaled per value via `multi_dict` (unknown values scale by 1.0).
    """
    labels = np.array([entry[attr_name] for entry in items])
    unique_attr_names = np.unique(labels).tolist()
    # Inverse occurrence count of each unique label, keyed for direct lookup.
    inv_freq = {label: 1.0 / (labels == label).sum() for label in unique_attr_names}
    weights = np.array([inv_freq[label] for label in labels])
    weights = weights / np.linalg.norm(weights)
    if multi_dict is not None:
        scale = np.array([multi_dict.get(entry[attr_name], 1.0) for entry in items])
        weights = weights * scale
    return (
        torch.from_numpy(weights).float(),
        unique_attr_names,
        np.unique(weights).tolist(),
    )
|
|
||||||
|
|
||||||
|
|
||||||
class ForwardTTSE2eF0Dataset(F0Dataset):
|
class ForwardTTSE2eF0Dataset(F0Dataset):
|
||||||
"""Override F0Dataset to avoid slow computing of pitches"""
|
"""Override F0Dataset to avoid slow computing of pitches"""
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,7 @@ from trainer.torch import DistributedSampler, DistributedSamplerWrapper
|
||||||
from trainer.trainer_utils import get_optimizer, get_scheduler
|
from trainer.trainer_utils import get_optimizer, get_scheduler
|
||||||
|
|
||||||
from TTS.tts.configs.shared_configs import CharactersConfig
|
from TTS.tts.configs.shared_configs import CharactersConfig
|
||||||
from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
|
from TTS.tts.datasets.dataset import TTSDataset, _parse_sample, get_attribute_balancer_weights
|
||||||
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
||||||
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
|
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
|
||||||
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
|
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
|
||||||
|
@ -219,30 +219,6 @@ class VitsAudioConfig(Coqpit):
|
||||||
##############################
|
##############################
|
||||||
|
|
||||||
|
|
||||||
def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None):
    """Create inverse frequency weights for balancing the dataset.

    Use `multi_dict` to scale relative weights."""
    values = np.array([sample[attr_name] for sample in items])
    unique_attr_names = np.unique(values).tolist()
    # Per-value inverse frequency, built once and looked up per sample.
    inv_freq = {name: 1.0 / len(np.where(values == name)[0]) for name in unique_attr_names}
    sample_weights = np.array([inv_freq[value] for value in values])
    sample_weights = sample_weights / np.linalg.norm(sample_weights)
    if multi_dict is not None:
        # check if all keys are in the multi_dict
        for k in multi_dict:
            assert k in unique_attr_names, f"{k} not in {unique_attr_names}"
        # scale weights
        multipliers = np.array([multi_dict.get(sample[attr_name], 1.0) for sample in items])
        sample_weights = sample_weights * multipliers
    return (
        torch.from_numpy(sample_weights).float(),
        unique_attr_names,
        np.unique(sample_weights).tolist(),
    )
|
|
||||||
|
|
||||||
|
|
||||||
class VitsDataset(TTSDataset):
|
class VitsDataset(TTSDataset):
|
||||||
def __init__(self, model_args, *args, **kwargs):
|
def __init__(self, model_args, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
Loading…
Reference in New Issue