import random

import numpy as np
import torch
from torch.utils.data import Dataset

from TTS.speaker_encoder.utils.generic_utils import AugmentWAV, Storage


class SpeakerEncoderDataset(Dataset):
    """Randomly samples speakers and builds speaker-grouped mel batches.

    ``__getitem__`` ignores its index and returns a randomly drawn
    ``(speaker, class_id)`` pair. The actual audio loading, cropping and
    feature extraction happens in :meth:`collate_fn`, which must therefore be
    passed to the ``DataLoader`` as its ``collate_fn``.
    """

    def __init__(
        self,
        ap,
        meta_data,
        voice_len=1.6,
        num_speakers_in_batch=64,
        storage_size=1,
        sample_from_storage_p=0.5,
        num_utter_per_speaker=10,
        skip_speakers=False,
        verbose=False,
        augmentation_config=None,
    ):
        """
        Args:
            ap (TTS.tts.utils.AudioProcessor): audio processor object.
            meta_data (list): list of dataset instances; each item is indexed
                as ``(text, wav_path, speaker_name)``.
            voice_len (float): voice segment length in seconds; converted to
                a number of samples and stored as ``self.seq_len``.
            num_speakers_in_batch (int): number of speakers per batch, used to
                size the storage.
            storage_size (int): number of batches the storage can hold.
            sample_from_storage_p (float): probability of taking a speaker
                group from the storage (once it is full) instead of disk.
            num_utter_per_speaker (int): utterances sampled per speaker.
            skip_speakers (bool): drop speakers that have fewer than
                ``num_utter_per_speaker`` utterances.
            verbose (bool): print diagnostic information.
            augmentation_config (dict): optional augmentation settings. Reads
                the key ``"p"`` and, when present, ``"additive"``, ``"rir"``
                and ``"gaussian"``.
        """
        super().__init__()
        self.items = meta_data
        self.sample_rate = ap.sample_rate
        self.seq_len = int(voice_len * self.sample_rate)
        self.num_speakers_in_batch = num_speakers_in_batch
        self.num_utter_per_speaker = num_utter_per_speaker
        self.skip_speakers = skip_speakers
        self.ap = ap
        self.verbose = verbose
        self.__parse_items()

        storage_max_size = storage_size * num_speakers_in_batch
        self.storage = Storage(
            maxsize=storage_max_size, storage_batchs=storage_size, num_speakers_in_batch=num_speakers_in_batch
        )
        self.sample_from_storage_p = float(sample_from_storage_p)

        # Class ids are assigned once, in sorted speaker order. They stay
        # valid even if speakers are later removed from ``self.speakers``.
        self.speakerid_to_classid = {key: i for i, key in enumerate(sorted(self.speakers))}

        # Augmentation
        self.augmentator = None
        self.gaussian_augmentation_config = None
        # Always defined so the attribute can be read without a config.
        self.data_augmentation_p = None
        if augmentation_config:
            self.data_augmentation_p = augmentation_config["p"]
            if self.data_augmentation_p and ("additive" in augmentation_config or "rir" in augmentation_config):
                self.augmentator = AugmentWAV(ap, augmentation_config)

            if "gaussian" in augmentation_config:
                self.gaussian_augmentation_config = augmentation_config["gaussian"]

        if self.verbose:
            print("\n > DataLoader initialization")
            print(f" | > Speakers per Batch: {num_speakers_in_batch}")
            print(f" | > Storage Size: {storage_max_size} instances, each with {num_utter_per_speaker} utters")
            print(f" | > Sample_from_storage_p : {self.sample_from_storage_p}")
            print(f" | > Number of instances : {len(self.items)}")
            print(f" | > Sequence length: {self.seq_len}")
            print(f" | > Num speakers: {len(self.speakers)}")

    def load_wav(self, filename):
        """Load ``filename`` through the audio processor at its sample rate."""
        audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
        return audio

    def load_data(self, idx):
        """Load item ``idx`` and return its mel spectrogram with metadata.

        Returns:
            dict: keys ``"mel"`` (float32 mel spectrogram), ``"item_idx"``
            (the wav path of the item) and ``"speaker_name"``.

        Raises:
            AssertionError: if the item's text or its audio is empty.
        """
        text, wav_file, speaker_name = self.items[idx]
        wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)

        # ``len`` instead of ``.size``: strings have no ``.size`` attribute.
        # Validate before the mel computation so empty audio fails here.
        assert len(text) > 0, wav_file
        assert wav.size > 0, wav_file

        mel = self.ap.melspectrogram(wav).astype("float32")
        return {
            "mel": mel,
            "item_idx": wav_file,
            "speaker_name": speaker_name,
        }

    def __parse_items(self):
        """Build ``self.speaker_to_utters`` and ``self.speakers`` from items."""
        self.speaker_to_utters = {}
        for item in self.items:
            # item[1] is the wav path, item[2] the speaker name.
            self.speaker_to_utters.setdefault(item[2], []).append(item[1])

        if self.skip_speakers:
            self.speaker_to_utters = {
                name: utters
                for name, utters in self.speaker_to_utters.items()
                if len(utters) >= self.num_utter_per_speaker
            }

        self.speakers = list(self.speaker_to_utters)

    def __len__(self):
        # Items are drawn at random and the index is ignored by
        # ``__getitem__``, so a very large nominal length is reported.
        return int(1e10)

    def get_num_speakers(self):
        """Return the number of speakers currently available for sampling."""
        return len(self.speakers)

    def __sample_speaker(self, ignore_speakers=None):
        """Draw a random speaker and ``num_utter_per_speaker`` of its utterances.

        Args:
            ignore_speakers: optional collection of class ids that should not
                be returned. If every remaining speaker is excluded, the
                restriction cannot be honoured and the unrestricted draw is
                returned instead of looping forever.

        Returns:
            tuple: ``(speaker, utters)``. Utterances are drawn with
            replacement only when the speaker has fewer than
            ``num_utter_per_speaker`` of them.
        """
        speaker = random.sample(self.speakers, 1)[0]

        if ignore_speakers and self.speakerid_to_classid[speaker] in ignore_speakers:
            allowed = [s for s in self.speakers if self.speakerid_to_classid[s] not in ignore_speakers]
            if allowed:
                speaker = random.sample(allowed, 1)[0]

        utter_pool = self.speaker_to_utters[speaker]
        if self.num_utter_per_speaker > len(utter_pool):
            utters = random.choices(utter_pool, k=self.num_utter_per_speaker)
        else:
            utters = random.sample(utter_pool, self.num_utter_per_speaker)
        return speaker, utters

    def __sample_speaker_utterances(self, speaker):
        """
        Sample all M utterances for the given speaker.

        Utterances not longer than ``seq_len`` samples are removed from the
        speaker's pool. A speaker left with fewer than two utterances is
        removed from ``self.speakers`` and replaced by a newly drawn one; the
        partially collected group is then discarded so that all returned
        utterances and labels belong to a single speaker.

        Returns:
            tuple: ``(wavs, labels)``, two lists of length
            ``num_utter_per_speaker``.

        Raises:
            ValueError: if no speaker is left to sample from.
        """
        wavs = []
        labels = []
        while len(wavs) < self.num_utter_per_speaker:
            utter_pool = self.speaker_to_utters[speaker]

            # remove speakers that have num_utter less than 2
            if len(utter_pool) <= 1:
                if speaker in self.speakers:
                    self.speakers.remove(speaker)
                if not self.speakers:
                    raise ValueError("No speaker with enough usable utterances is left to sample from.")
                speaker, _ = self.__sample_speaker()
                # Restart the group: do not mix utterances of two speakers.
                wavs.clear()
                labels.clear()
                continue

            utter = random.sample(utter_pool, 1)[0]
            wav = self.load_wav(utter)
            if wav.shape[0] - self.seq_len <= 0:
                # Too short to crop a full segment from; never pick it again.
                utter_pool.remove(utter)
                continue

            if self.augmentator is not None and self.data_augmentation_p:
                if random.random() < self.data_augmentation_p:
                    wav = self.augmentator.apply_one(wav)

            wavs.append(wav)
            labels.append(self.speakerid_to_classid[speaker])
        return wavs, labels

    def __getitem__(self, idx):
        """Return a random ``(speaker, class_id)`` pair; ``idx`` is ignored."""
        speaker, _ = self.__sample_speaker()
        speaker_id = self.speakerid_to_classid[speaker]
        return speaker, speaker_id

    def __load_from_disk_and_storage(self, speaker):
        """Load a fresh utterance group for ``speaker`` and add it to storage."""
        # don't sample from storage, but from HDD
        wavs_, labels_ = self.__sample_speaker_utterances(speaker)
        # put the newly loaded item into storage
        self.storage.append((wavs_, labels_))
        return wavs_, labels_

    def collate_fn(self, batch):
        """Turn ``(speaker, class_id)`` pairs into mel features and labels.

        For each pair, a group of utterances is taken from the storage or from
        disk, a random ``seq_len``-sample window is cropped from every
        utterance, optional gaussian noise is added, and the mel spectrogram
        is computed.

        Args:
            batch: iterable of ``(speaker, class_id)`` pairs as produced by
                ``__getitem__``.

        Returns:
            tuple: ``(feats, labels)``. ``feats`` is the stack of all mel
            tensors with dimensions 1 and 2 swapped; ``labels`` is the stack
            of one ``LongTensor`` of class ids per speaker group.
        """
        # get the batch speaker_ids
        speakers_id_in_batch = {int(speaker_id) for _, speaker_id in batch}

        labels = []
        feats = []
        speakers = set()  # class ids already placed in this batch
        gauss = self.gaussian_augmentation_config

        for speaker, speaker_id in batch:
            speaker_id = int(speaker_id)

            # ensure that a speaker appears only once in the batch
            if speaker_id in speakers:
                # remove current speaker
                speakers_id_in_batch.discard(speaker_id)
                # Exclude already used speakers as well as the pending ones,
                # otherwise the same speaker could be drawn again.
                speaker, _ = self.__sample_speaker(ignore_speakers=speakers_id_in_batch | speakers)
                speaker_id = self.speakerid_to_classid[speaker]
                speakers_id_in_batch.add(speaker_id)

            if random.random() < self.sample_from_storage_p and self.storage.full():
                # sample from storage (if full)
                wavs_, labels_ = self.storage.get_random_sample_fast()

                # force choose the current speaker or other not in batch
                # It's necessary for ideal training with AngleProto and GE2E losses
                if labels_[0] in speakers_id_in_batch and labels_[0] != speaker_id:
                    attempts = 0
                    while True:
                        wavs_, labels_ = self.storage.get_random_sample_fast()
                        if labels_[0] == speaker_id or labels_[0] not in speakers_id_in_batch:
                            break

                        attempts += 1
                        # Try 5 times after that load from disk
                        if attempts >= 5:
                            wavs_, labels_ = self.__load_from_disk_and_storage(speaker)
                            break
            else:
                # don't sample from storage, but from HDD
                wavs_, labels_ = self.__load_from_disk_and_storage(speaker)

            # append speaker for control
            speakers.add(labels_[0])

            # remove current speaker and append other
            speakers_id_in_batch.discard(speaker_id)
            speakers_id_in_batch.add(labels_[0])

            # get a random subset of each of the wavs and extract mel spectrograms.
            feats_ = []
            for wav in wavs_:
                offset = random.randint(0, wav.shape[0] - self.seq_len)
                wav = wav[offset : offset + self.seq_len]

                # add random gaussian noise
                if gauss and gauss["p"] and random.random() < gauss["p"]:
                    # NOTE(review): the two positional arguments of
                    # np.random.normal are the mean and the standard
                    # deviation; confirm that passing min/max amplitude
                    # there is intended.
                    noise = np.random.normal(
                        gauss["min_amplitude"],
                        gauss["max_amplitude"],
                        size=len(wav),
                    )
                    # Add out of place: ``wav`` is a slice view, and an
                    # in-place add would write the noise into the source
                    # array, which may be kept in storage and reused.
                    wav = (wav + noise).astype(wav.dtype, copy=False)

                mel = self.ap.melspectrogram(wav)
                feats_.append(torch.FloatTensor(mel))

            labels.append(torch.LongTensor(labels_))
            feats.extend(feats_)

        feats = torch.stack(feats)
        labels = torch.stack(labels)

        return feats.transpose(1, 2), labels