def __parse_items(self):
    """
    Group utterance paths by speaker id and record the resulting speakers.

    Reads ``self.items``, using ``item[1]`` as the utterance path and
    ``item[2]`` as the speaker id, and sets:

    - ``self.speaker_to_utters``: dict mapping each speaker id to the list
      of its utterance paths, in the order the items are encountered.
    - ``self.speakers``: list of the speaker ids kept in that dict.

    When ``self.skip_speakers`` is truthy, speakers with fewer than
    ``self.num_utter_per_speaker`` utterances are dropped from both.
    """
    # Single pass over the items: each path is appended to its speaker's list.
    speaker_to_utters = {}
    for item in self.items:
        speaker_to_utters.setdefault(item[2], []).append(item[1])

    if self.skip_speakers:
        speaker_to_utters = {
            speaker: utters
            for speaker, utters in speaker_to_utters.items()
            if len(utters) >= self.num_utter_per_speaker
        }

    self.speaker_to_utters = speaker_to_utters
    # Iterating a dict yields its keys only, so take the keys directly;
    # unpacking each key as a (key, value) pair would fail or corrupt ids.
    self.speakers = list(speaker_to_utters)


def __len__(self):
    """Return a fixed, very large length (``int(1e10)``)."""
    return int(1e10)