mirror of https://github.com/coqui-ai/TTS.git
get_speaker_weighted_sampler
This commit is contained in:
parent
56480360cf
commit
9d2c445e3d
|
@ -12,7 +12,7 @@ from torch.utils.data.distributed import DistributedSampler
|
|||
from TTS.model import BaseModel
|
||||
from TTS.tts.configs.shared_configs import CharactersConfig
|
||||
from TTS.tts.datasets.dataset import TTSDataset
|
||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
|
||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
|
||||
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.text import make_symbols
|
||||
|
@ -334,6 +334,9 @@ class BaseTTS(BaseModel):
|
|||
if getattr(config, "use_language_weighted_sampler", False):
|
||||
print(" > Using Language weighted sampler")
|
||||
sampler = get_language_weighted_sampler(dataset.items)
|
||||
elif getattr(config, "use_speaker_weighted_sampler", False):
|
||||
print(" > Using Language weighted sampler")
|
||||
sampler = get_speaker_weighted_sampler(dataset.items)
|
||||
|
||||
|
||||
loader = DataLoader(
|
||||
|
|
|
@ -431,3 +431,12 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
|
|||
else:
|
||||
speaker_manager.save_speaker_ids_to_file(out_file_path)
|
||||
return speaker_manager
|
||||
|
||||
def get_speaker_weighted_sampler(items: list):
|
||||
speaker_names = np.array([item[2] for item in items])
|
||||
unique_speaker_names = np.unique(speaker_names).tolist()
|
||||
speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
|
||||
speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names])
|
||||
weight_speaker = 1. / speaker_count
|
||||
dataset_samples_weight = torch.from_numpy(np.array([weight_speaker[l] for l in speaker_ids])).double()
|
||||
return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
|
Loading…
Reference in New Issue