get_speaker_weighted_sampler

This commit is contained in:
WeberJulian 2021-09-19 23:34:38 +02:00 committed by Eren Gölge
parent 56480360cf
commit 9d2c445e3d
2 changed files with 13 additions and 1 deletions

View File

@ -12,7 +12,7 @@ from torch.utils.data.distributed import DistributedSampler
from TTS.model import BaseModel
from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text import make_symbols
@ -334,6 +334,9 @@ class BaseTTS(BaseModel):
if getattr(config, "use_language_weighted_sampler", False):
print(" > Using Language weighted sampler")
sampler = get_language_weighted_sampler(dataset.items)
elif getattr(config, "use_speaker_weighted_sampler", False):
print(" > Using Language weighted sampler")
sampler = get_speaker_weighted_sampler(dataset.items)
loader = DataLoader(

View File

@ -431,3 +431,12 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
else:
speaker_manager.save_speaker_ids_to_file(out_file_path)
return speaker_manager
def get_speaker_weighted_sampler(items: list):
speaker_names = np.array([item[2] for item in items])
unique_speaker_names = np.unique(speaker_names).tolist()
speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names])
weight_speaker = 1. / speaker_count
dataset_samples_weight = torch.from_numpy(np.array([weight_speaker[l] for l in speaker_ids])).double()
return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))