mirror of https://github.com/coqui-ai/TTS.git
get_speaker_weighted_sampler
This commit is contained in:
parent
56480360cf
commit
9d2c445e3d
|
@ -12,7 +12,7 @@ from torch.utils.data.distributed import DistributedSampler
|
||||||
from TTS.model import BaseModel
|
from TTS.model import BaseModel
|
||||||
from TTS.tts.configs.shared_configs import CharactersConfig
|
from TTS.tts.configs.shared_configs import CharactersConfig
|
||||||
from TTS.tts.datasets.dataset import TTSDataset
|
from TTS.tts.datasets.dataset import TTSDataset
|
||||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
|
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
|
||||||
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
|
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
|
||||||
from TTS.tts.utils.synthesis import synthesis
|
from TTS.tts.utils.synthesis import synthesis
|
||||||
from TTS.tts.utils.text import make_symbols
|
from TTS.tts.utils.text import make_symbols
|
||||||
|
@ -334,6 +334,9 @@ class BaseTTS(BaseModel):
|
||||||
if getattr(config, "use_language_weighted_sampler", False):
|
if getattr(config, "use_language_weighted_sampler", False):
|
||||||
print(" > Using Language weighted sampler")
|
print(" > Using Language weighted sampler")
|
||||||
sampler = get_language_weighted_sampler(dataset.items)
|
sampler = get_language_weighted_sampler(dataset.items)
|
||||||
|
elif getattr(config, "use_speaker_weighted_sampler", False):
|
||||||
|
print(" > Using Language weighted sampler")
|
||||||
|
sampler = get_speaker_weighted_sampler(dataset.items)
|
||||||
|
|
||||||
|
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
|
|
|
@ -431,3 +431,12 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
|
||||||
else:
|
else:
|
||||||
speaker_manager.save_speaker_ids_to_file(out_file_path)
|
speaker_manager.save_speaker_ids_to_file(out_file_path)
|
||||||
return speaker_manager
|
return speaker_manager
|
||||||
|
|
||||||
|
def get_speaker_weighted_sampler(items: list) -> "WeightedRandomSampler":
    """Create a sampler that balances training samples across speakers.

    Each sample receives the weight ``1 / (number of samples of its speaker)``,
    so every speaker carries the same total sampling mass regardless of how
    many samples it contributes to the dataset.

    Args:
        items (list): Dataset samples. The speaker name is read from index 2
            of each sample (``item[2]``).

    Returns:
        WeightedRandomSampler: Sampler holding one float64 weight per sample
        and drawing ``len(items)`` samples per epoch.
    """
    speaker_names = np.array([item[2] for item in items])
    # A single np.unique call yields, for the sorted unique speakers, the
    # per-sample speaker index and the per-speaker sample count. This replaces
    # a per-sample ``list.index`` lookup and a per-speaker full array scan.
    _, speaker_ids, speaker_count = np.unique(speaker_names, return_inverse=True, return_counts=True)
    weight_speaker = 1.0 / speaker_count
    # Fancy indexing gathers each sample's speaker weight without a Python loop.
    dataset_samples_weight = torch.from_numpy(weight_speaker[speaker_ids]).double()
    return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
|
Loading…
Reference in New Issue