From 005bba60b018a804a0752ec77973b11aba70ab4b Mon Sep 17 00:00:00 2001 From: WeberJulian Date: Sun, 19 Sep 2021 23:34:38 +0200 Subject: [PATCH] get_speaker_weighted_sampler --- TTS/tts/models/base_tts.py | 5 ++++- TTS/tts/utils/speakers.py | 9 +++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index c03a7df5..9d722222 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -12,7 +12,7 @@ from torch.utils.data.distributed import DistributedSampler from TTS.model import BaseModel from TTS.tts.configs.shared_configs import CharactersConfig from TTS.tts.datasets.dataset import TTSDataset -from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager +from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text import make_symbols @@ -334,6 +334,9 @@ class BaseTTS(BaseModel): if getattr(config, "use_language_weighted_sampler", False): print(" > Using Language weighted sampler") sampler = get_language_weighted_sampler(dataset.items) + elif getattr(config, "use_speaker_weighted_sampler", False): + print(" > Using Language weighted sampler") + sampler = get_speaker_weighted_sampler(dataset.items) loader = DataLoader( diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index 282875af..8ccbdafc 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -431,3 +431,12 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None, else: speaker_manager.save_speaker_ids_to_file(out_file_path) return speaker_manager + +def get_speaker_weighted_sampler(items: list): + speaker_names = np.array([item[2] for item in items]) + unique_speaker_names = np.unique(speaker_names).tolist() + speaker_ids = [unique_speaker_names.index(l) for l in speaker_names] + speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names]) + weight_speaker = 1. / speaker_count + dataset_samples_weight = torch.from_numpy(np.array([weight_speaker[l] for l in speaker_ids])).double() + return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight)) \ No newline at end of file