mirror of https://github.com/coqui-ai/TTS.git
Add alphas to control language and speaker balancer (#1216)
* Add alphas to control language and speaker balancer * Add docs for speaker and language samplers * Change the Samplers weights to float for save memory * Change the test_samplers to unittest format * Add get_sampler method in BaseTTS * Fix rebase issues * Add language and speaker samplers support for DDP training * Rename distributed sampler wrapper * Remove the DistributedSamplerWrapper and use the one from Trainer * Bugfix after rebase * Move the samplers config to tts config
This commit is contained in:
parent
f381e29b91
commit
917f417ac4
|
@ -258,4 +258,3 @@ class BaseTrainingConfig(TrainerConfig):
|
||||||
num_loader_workers: int = 0
|
num_loader_workers: int = 0
|
||||||
num_eval_loader_workers: int = 0
|
num_eval_loader_workers: int = 0
|
||||||
use_noise_augment: bool = False
|
use_noise_augment: bool = False
|
||||||
use_language_weighted_sampler: bool = False
|
|
||||||
|
|
|
@ -220,6 +220,18 @@ class BaseTTSConfig(BaseTrainingConfig):
|
||||||
eval_split_size (float):
|
eval_split_size (float):
|
||||||
If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
|
If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
|
||||||
If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
|
If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
|
||||||
|
|
||||||
|
use_speaker_weighted_sampler (bool):
|
||||||
|
Enable / Disable the batch balancer by speaker. Defaults to ```False```.
|
||||||
|
|
||||||
|
speaker_weighted_sampler_alpha (float):
|
||||||
|
Number that control the influence of the speaker sampler weights. Defaults to ```1.0```.
|
||||||
|
|
||||||
|
use_language_weighted_sampler (bool):
|
||||||
|
Enable / Disable the batch balancer by language. Defaults to ```False```.
|
||||||
|
|
||||||
|
language_weighted_sampler_alpha (float):
|
||||||
|
Number that control the influence of the language sampler weights. Defaults to ```1.0```.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
|
audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
|
||||||
|
@ -262,3 +274,8 @@ class BaseTTSConfig(BaseTrainingConfig):
|
||||||
# evaluation
|
# evaluation
|
||||||
eval_split_max_size: int = None
|
eval_split_max_size: int = None
|
||||||
eval_split_size: float = 0.01
|
eval_split_size: float = 0.01
|
||||||
|
# weighted samplers
|
||||||
|
use_speaker_weighted_sampler: bool = False
|
||||||
|
speaker_weighted_sampler_alpha: float = 1.0
|
||||||
|
use_language_weighted_sampler: bool = False
|
||||||
|
language_weighted_sampler_alpha: float = 1.0
|
||||||
|
|
|
@ -7,14 +7,15 @@ import torch.distributed as dist
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from trainer.torch import DistributedSampler, DistributedSamplerWrapper
|
||||||
|
|
||||||
from TTS.model import BaseTrainerModel
|
from TTS.model import BaseTrainerModel
|
||||||
from TTS.tts.datasets.dataset import TTSDataset
|
from TTS.tts.datasets.dataset import TTSDataset
|
||||||
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
|
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
|
||||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_weighted_sampler
|
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_balancer_weights
|
||||||
from TTS.tts.utils.synthesis import synthesis
|
from TTS.tts.utils.synthesis import synthesis
|
||||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||||
|
from torch.utils.data.sampler import WeightedRandomSampler
|
||||||
|
|
||||||
# pylint: skip-file
|
# pylint: skip-file
|
||||||
|
|
||||||
|
@ -232,6 +233,36 @@ class BaseTTS(BaseTrainerModel):
|
||||||
"language_ids": language_ids,
|
"language_ids": language_ids,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
|
||||||
|
weights = None
|
||||||
|
data_items = dataset.samples
|
||||||
|
|
||||||
|
if getattr(config, "use_language_weighted_sampler", False):
|
||||||
|
alpha = getattr(config, "language_weighted_sampler_alpha", 1.0)
|
||||||
|
print(" > Using Language weighted sampler with alpha:", alpha)
|
||||||
|
weights = get_language_balancer_weights(data_items) * alpha
|
||||||
|
|
||||||
|
if getattr(config, "use_speaker_weighted_sampler", False):
|
||||||
|
alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0)
|
||||||
|
print(" > Using Speaker weighted sampler with alpha:", alpha)
|
||||||
|
if weights is not None:
|
||||||
|
weights += get_speaker_balancer_weights(data_items) * alpha
|
||||||
|
else:
|
||||||
|
weights = get_speaker_balancer_weights(data_items) * alpha
|
||||||
|
|
||||||
|
if weights is not None:
|
||||||
|
sampler = WeightedRandomSampler(weights, len(weights))
|
||||||
|
else:
|
||||||
|
sampler = None
|
||||||
|
|
||||||
|
# sampler for DDP
|
||||||
|
if sampler is None:
|
||||||
|
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||||
|
else: # If a sampler is already defined use this sampler and DDP sampler together
|
||||||
|
sampler = DistributedSamplerWrapper(sampler) if num_gpus > 1 else sampler
|
||||||
|
|
||||||
|
return sampler
|
||||||
|
|
||||||
def get_data_loader(
|
def get_data_loader(
|
||||||
self,
|
self,
|
||||||
config: Coqpit,
|
config: Coqpit,
|
||||||
|
@ -300,25 +331,8 @@ class BaseTTS(BaseTrainerModel):
|
||||||
# sort input sequences from short to long
|
# sort input sequences from short to long
|
||||||
dataset.preprocess_samples()
|
dataset.preprocess_samples()
|
||||||
|
|
||||||
# sampler for DDP
|
# get samplers
|
||||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
sampler = self.get_sampler(config, dataset, num_gpus)
|
||||||
|
|
||||||
# Weighted samplers
|
|
||||||
# TODO: make this DDP amenable
|
|
||||||
assert not (
|
|
||||||
num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
|
|
||||||
), "language_weighted_sampler is not supported with DistributedSampler"
|
|
||||||
assert not (
|
|
||||||
num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)
|
|
||||||
), "speaker_weighted_sampler is not supported with DistributedSampler"
|
|
||||||
|
|
||||||
if sampler is None:
|
|
||||||
if getattr(config, "use_language_weighted_sampler", False):
|
|
||||||
print(" > Using Language weighted sampler")
|
|
||||||
sampler = get_language_weighted_sampler(dataset.samples)
|
|
||||||
elif getattr(config, "use_speaker_weighted_sampler", False):
|
|
||||||
print(" > Using Language weighted sampler")
|
|
||||||
sampler = get_speaker_weighted_sampler(dataset.samples)
|
|
||||||
|
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
|
|
|
@ -13,7 +13,6 @@ from torch import nn
|
||||||
from torch.cuda.amp.autocast_mode import autocast
|
from torch.cuda.amp.autocast_mode import autocast
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
|
||||||
from trainer.trainer_utils import get_optimizer, get_scheduler
|
from trainer.trainer_utils import get_optimizer, get_scheduler
|
||||||
|
|
||||||
from TTS.tts.configs.shared_configs import CharactersConfig
|
from TTS.tts.configs.shared_configs import CharactersConfig
|
||||||
|
@ -24,8 +23,8 @@ from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlock
|
||||||
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
|
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
|
||||||
from TTS.tts.models.base_tts import BaseTTS
|
from TTS.tts.models.base_tts import BaseTTS
|
||||||
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
|
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
|
||||||
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
|
from TTS.tts.utils.languages import LanguageManager
|
||||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_weighted_sampler
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
from TTS.tts.utils.synthesis import synthesis
|
from TTS.tts.utils.synthesis import synthesis
|
||||||
from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
|
from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
|
||||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||||
|
@ -1354,31 +1353,15 @@ class Vits(BaseTTS):
|
||||||
# sort input sequences from short to long
|
# sort input sequences from short to long
|
||||||
dataset.preprocess_samples()
|
dataset.preprocess_samples()
|
||||||
|
|
||||||
# sampler for DDP
|
# get samplers
|
||||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
sampler = self.get_sampler(config, dataset, num_gpus)
|
||||||
|
|
||||||
# Weighted samplers
|
|
||||||
# TODO: make this DDP amenable
|
|
||||||
assert not (
|
|
||||||
num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
|
|
||||||
), "language_weighted_sampler is not supported with DistributedSampler"
|
|
||||||
assert not (
|
|
||||||
num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)
|
|
||||||
), "speaker_weighted_sampler is not supported with DistributedSampler"
|
|
||||||
|
|
||||||
if sampler is None:
|
|
||||||
if getattr(config, "use_language_weighted_sampler", False):
|
|
||||||
print(" > Using Language weighted sampler")
|
|
||||||
sampler = get_language_weighted_sampler(dataset.samples)
|
|
||||||
elif getattr(config, "use_speaker_weighted_sampler", False):
|
|
||||||
print(" > Using Language weighted sampler")
|
|
||||||
sampler = get_speaker_weighted_sampler(dataset.samples)
|
|
||||||
|
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=config.eval_batch_size if is_eval else config.batch_size,
|
batch_size=config.eval_batch_size if is_eval else config.batch_size,
|
||||||
shuffle=False, # shuffle is done in the dataset.
|
shuffle=False, # shuffle is done in the dataset.
|
||||||
drop_last=False, # setting this False might cause issues in AMP training.
|
drop_last=False, # setting this False might cause issues in AMP training.
|
||||||
|
sampler=sampler,
|
||||||
collate_fn=dataset.collate_fn,
|
collate_fn=dataset.collate_fn,
|
||||||
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
|
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
|
||||||
pin_memory=False,
|
pin_memory=False,
|
||||||
|
|
|
@ -6,7 +6,6 @@ import fsspec
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
from torch.utils.data.sampler import WeightedRandomSampler
|
|
||||||
|
|
||||||
from TTS.config import check_config_and_model_args
|
from TTS.config import check_config_and_model_args
|
||||||
|
|
||||||
|
@ -128,11 +127,14 @@ def _set_file_path(path):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_language_weighted_sampler(items: list):
|
def get_language_balancer_weights(items: list):
|
||||||
language_names = np.array([item["language"] for item in items])
|
language_names = np.array([item["language"] for item in items])
|
||||||
unique_language_names = np.unique(language_names).tolist()
|
unique_language_names = np.unique(language_names).tolist()
|
||||||
language_ids = [unique_language_names.index(l) for l in language_names]
|
language_ids = [unique_language_names.index(l) for l in language_names]
|
||||||
language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])
|
language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])
|
||||||
weight_language = 1.0 / language_count
|
weight_language = 1.0 / language_count
|
||||||
dataset_samples_weight = torch.from_numpy(np.array([weight_language[l] for l in language_ids])).double()
|
# get weight for each sample
|
||||||
return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
|
dataset_samples_weight = np.array([weight_language[l] for l in language_ids])
|
||||||
|
# normalize
|
||||||
|
dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
|
||||||
|
return torch.from_numpy(dataset_samples_weight).float()
|
||||||
|
|
|
@ -7,7 +7,6 @@ import fsspec
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
from torch.utils.data.sampler import WeightedRandomSampler
|
|
||||||
|
|
||||||
from TTS.config import get_from_config_or_model_args_with_default, load_config
|
from TTS.config import get_from_config_or_model_args_with_default, load_config
|
||||||
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
|
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
|
||||||
|
@ -449,11 +448,13 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
|
||||||
return speaker_manager
|
return speaker_manager
|
||||||
|
|
||||||
|
|
||||||
def get_speaker_weighted_sampler(items: list):
|
def get_speaker_balancer_weights(items: list):
|
||||||
speaker_names = np.array([item["speaker_name"] for item in items])
|
speaker_names = np.array([item["speaker_name"] for item in items])
|
||||||
unique_speaker_names = np.unique(speaker_names).tolist()
|
unique_speaker_names = np.unique(speaker_names).tolist()
|
||||||
speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
|
speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
|
||||||
speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names])
|
speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names])
|
||||||
weight_speaker = 1.0 / speaker_count
|
weight_speaker = 1.0 / speaker_count
|
||||||
dataset_samples_weight = torch.from_numpy(np.array([weight_speaker[l] for l in speaker_ids])).double()
|
dataset_samples_weight = np.array([weight_speaker[l] for l in speaker_ids])
|
||||||
return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
|
# normalize
|
||||||
|
dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
|
||||||
|
return torch.from_numpy(dataset_samples_weight).float()
|
||||||
|
|
|
@ -1,10 +1,13 @@
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from TTS.config.shared_configs import BaseDatasetConfig
|
from TTS.config.shared_configs import BaseDatasetConfig
|
||||||
from TTS.tts.datasets import load_tts_samples
|
from TTS.tts.datasets import load_tts_samples
|
||||||
from TTS.tts.utils.languages import get_language_weighted_sampler
|
from TTS.tts.utils.languages import get_language_balancer_weights
|
||||||
|
from TTS.tts.utils.speakers import get_speaker_balancer_weights
|
||||||
|
|
||||||
# Fixing random state to avoid random fails
|
# Fixing random state to avoid random fails
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
|
@ -25,34 +28,57 @@ dataset_config_pt = BaseDatasetConfig(
|
||||||
language="pt-br",
|
language="pt-br",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Adding the EN samples twice to create an unbalanced dataset
|
# Adding the EN samples twice to create a language unbalanced dataset
|
||||||
train_samples, eval_samples = load_tts_samples(
|
train_samples, eval_samples = load_tts_samples(
|
||||||
[dataset_config_en, dataset_config_en, dataset_config_pt], eval_split=True
|
[dataset_config_en, dataset_config_en, dataset_config_pt], eval_split=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# gerenate a speaker unbalanced dataset
|
||||||
|
for i, sample in enumerate(train_samples):
|
||||||
|
if i < 5:
|
||||||
|
sample["speaker_name"] = "ljspeech-0"
|
||||||
|
else:
|
||||||
|
sample["speaker_name"] = "ljspeech-1"
|
||||||
|
|
||||||
|
|
||||||
def is_balanced(lang_1, lang_2):
|
def is_balanced(lang_1, lang_2):
|
||||||
return 0.85 < lang_1 / lang_2 < 1.2
|
return 0.85 < lang_1 / lang_2 < 1.2
|
||||||
|
|
||||||
|
|
||||||
random_sampler = torch.utils.data.RandomSampler(train_samples)
|
class TestSamplers(unittest.TestCase):
|
||||||
ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)])
|
def test_language_random_sampler(self): # pylint: disable=no-self-use
|
||||||
en, pt = 0, 0
|
random_sampler = torch.utils.data.RandomSampler(train_samples)
|
||||||
for index in ids:
|
ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)])
|
||||||
if train_samples[index]["language"] == "en":
|
en, pt = 0, 0
|
||||||
en += 1
|
for index in ids:
|
||||||
else:
|
if train_samples[index]["language"] == "en":
|
||||||
pt += 1
|
en += 1
|
||||||
|
else:
|
||||||
|
pt += 1
|
||||||
|
|
||||||
assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"
|
assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"
|
||||||
|
|
||||||
weighted_sampler = get_language_weighted_sampler(train_samples)
|
def test_language_weighted_random_sampler(self): # pylint: disable=no-self-use
|
||||||
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
|
weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_language_balancer_weights(train_samples), len(train_samples))
|
||||||
en, pt = 0, 0
|
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
|
||||||
for index in ids:
|
en, pt = 0, 0
|
||||||
if train_samples[index]["language"] == "en":
|
for index in ids:
|
||||||
en += 1
|
if train_samples[index]["language"] == "en":
|
||||||
else:
|
en += 1
|
||||||
pt += 1
|
else:
|
||||||
|
pt += 1
|
||||||
|
|
||||||
assert is_balanced(en, pt), "Weighted sampler is supposed to be balanced"
|
assert is_balanced(en, pt), "Language Weighted sampler is supposed to be balanced"
|
||||||
|
|
||||||
|
def test_speaker_weighted_random_sampler(self): # pylint: disable=no-self-use
|
||||||
|
|
||||||
|
weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_speaker_balancer_weights(train_samples), len(train_samples))
|
||||||
|
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
|
||||||
|
spk1, spk2 = 0, 0
|
||||||
|
for index in ids:
|
||||||
|
if train_samples[index]["speaker_name"] == "ljspeech-0":
|
||||||
|
spk1 += 1
|
||||||
|
else:
|
||||||
|
spk2 += 1
|
||||||
|
|
||||||
|
assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced"
|
||||||
|
|
|
@ -45,7 +45,7 @@ config = VitsConfig(
|
||||||
["Be a voice, not an echo.", "ljspeech-0", None, "en"],
|
["Be a voice, not an echo.", "ljspeech-0", None, "en"],
|
||||||
["Be a voice, not an echo.", "ljspeech-1", None, "pt-br"],
|
["Be a voice, not an echo.", "ljspeech-1", None, "pt-br"],
|
||||||
],
|
],
|
||||||
datasets=[dataset_config_en, dataset_config_pt],
|
datasets=[dataset_config_en, dataset_config_en, dataset_config_en, dataset_config_pt],
|
||||||
)
|
)
|
||||||
# set audio config
|
# set audio config
|
||||||
config.audio.do_trim_silence = True
|
config.audio.do_trim_silence = True
|
||||||
|
@ -71,8 +71,11 @@ config.d_vector_dim = 256
|
||||||
config.model_args.use_sdp = True
|
config.model_args.use_sdp = True
|
||||||
config.use_sdp = True
|
config.use_sdp = True
|
||||||
|
|
||||||
# deactivate language sampler
|
# activate language and speaker samplers
|
||||||
config.use_language_weighted_sampler = False
|
config.use_language_weighted_sampler = True
|
||||||
|
config.language_weighted_sampler_alpha = 10
|
||||||
|
config.use_speaker_weighted_sampler = True
|
||||||
|
config.speaker_weighted_sampler_alpha = 5
|
||||||
|
|
||||||
config.save_json(config_path)
|
config.save_json(config_path)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue