import random

import numpy as np
import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, ap, meta_data, voice_len=1.6, num_speakers_in_batch=64,
                 num_utter_per_speaker=10, skip_speakers=False, verbose=False):
        """
        Args:
            ap (TTS.tts.utils.AudioProcessor): audio processor object.
            meta_data (list): list of dataset instances.
            voice_len (float): voice segment length in seconds.
            num_speakers_in_batch (int): number of speakers sampled per batch.
            num_utter_per_speaker (int): number of utterances sampled per speaker.
            skip_speakers (bool): skip speakers with fewer than num_utter_per_speaker utterances.
            verbose (bool): print diagnostic information.
        """
        self.items = meta_data
        self.sample_rate = ap.sample_rate
        self.voice_len = voice_len
        self.seq_len = int(voice_len * self.sample_rate)
        self.num_speakers_in_batch = num_speakers_in_batch
        self.num_utter_per_speaker = num_utter_per_speaker
        self.skip_speakers = skip_speakers
        self.ap = ap
        self.verbose = verbose
        self.__parse_items()
        if self.verbose:
            print("\n > DataLoader initialization")
            print(f" | > Number of instances : {len(self.items)}")
            print(f" | > Sequence length: {self.seq_len}")
            print(f" | > Num speakers: {len(self.speakers)}")

    def load_wav(self, filename):
        audio = self.ap.load_wav(filename)
        return audio

    def load_data(self, idx):
        text, wav_file, speaker_name = self.items[idx]
        wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
        mel = self.ap.melspectrogram(wav).astype("float32")
        # sample seq_len

        assert text.size > 0, self.items[idx][1]
        assert wav.size > 0, self.items[idx][1]

        sample = {
            "mel": mel,
            "item_idx": self.items[idx][1],
            "speaker_name": speaker_name,
        }
        return sample

    def __parse_items(self):
        """
        Find unique speaker ids and build a dict mapping each speaker id to its utterances.
        """
        speakers = list({item[-1] for item in self.items})
        self.speaker_to_utters = {}
        self.speakers = []
        for speaker in speakers:
            speaker_utters = [item[1] for item in self.items if item[2] == speaker]
            if len(speaker_utters) < self.num_utter_per_speaker and self.skip_speakers:
                print(
                    f" [!] Skipped speaker {speaker}. Not enough utterances "
                    f"({len(speaker_utters)} < {self.num_utter_per_speaker})."
                )
            else:
                self.speakers.append(speaker)
                self.speaker_to_utters[speaker] = speaker_utters

    def __len__(self):
        # Effectively infinite: items are drawn randomly, so epoch length is
        # bounded by the training loop rather than by the dataset size.
        return int(1e10)

    def __sample_speaker(self):
        speaker = random.sample(self.speakers, 1)[0]
        if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]):
            # Not enough unique utterances left: sample with replacement.
            utters = random.choices(
                self.speaker_to_utters[speaker], k=self.num_utter_per_speaker
            )
        else:
            utters = random.sample(
                self.speaker_to_utters[speaker], self.num_utter_per_speaker
            )
        return speaker, utters

    def __sample_speaker_utterances(self, speaker):
        """
        Sample all M utterances for the given speaker.
        """
        feats = []
        labels = []
        for _ in range(self.num_utter_per_speaker):
            # TODO: dummy but works
            while True:
                if len(self.speaker_to_utters[speaker]) > 0:
                    utter = random.sample(self.speaker_to_utters[speaker], 1)[0]
                else:
                    # Speaker ran out of usable utterances; fall back to another speaker.
                    self.speakers.remove(speaker)
                    speaker, _ = self.__sample_speaker()
                    continue
                wav = self.load_wav(utter)
                if wav.shape[0] - self.seq_len > 0:
                    break
                # Utterance is shorter than seq_len; drop it and resample.
                self.speaker_to_utters[speaker].remove(utter)

            offset = random.randint(0, wav.shape[0] - self.seq_len)
            mel = self.ap.melspectrogram(wav[offset : offset + self.seq_len])
            feats.append(torch.FloatTensor(mel))
            labels.append(speaker)
        return feats, labels

    def __getitem__(self, idx):
        # Only a speaker name is returned here; the actual utterance sampling
        # happens in collate_fn so each batch holds N speakers x M utterances.
        speaker, _ = self.__sample_speaker()
        return speaker

    def collate_fn(self, batch):
        labels = []
        feats = []
        for speaker in batch:
            feats_, labels_ = self.__sample_speaker_utterances(speaker)
            labels.append(labels_)
            feats.extend(feats_)
        feats = torch.stack(feats)
        # Shape: [num_speakers_in_batch * num_utter_per_speaker, frames, n_mels]
        return feats.transpose(1, 2), labels
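

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): shows how MyDataset is
# typically wired to a torch DataLoader through its custom collate_fn, with
# batch_size set to the number of speakers per batch. DummyAudioProcessor is
# a hypothetical stand-in for the real audio processor, and the meta_data
# layout (text, wav_path, speaker_name) is inferred from load_data() and
# __parse_items() above; adapt both to your actual setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    class DummyAudioProcessor:
        """Minimal stand-in exposing only the attributes MyDataset relies on."""

        sample_rate = 16000

        def load_wav(self, filename):
            # Return 2 seconds of noise instead of reading from disk.
            return np.random.randn(2 * self.sample_rate).astype(np.float32)

        def melspectrogram(self, wav):
            # Fake 80-band "mel" features, one frame per 256 samples.
            frames = len(wav) // 256
            return np.random.randn(80, frames).astype(np.float32)

    # Hypothetical metadata: (text, wav_path, speaker_name) tuples.
    meta_data = [
        ("hello", f"/tmp/spk{s}_utt{u}.wav", f"speaker_{s}")
        for s in range(4)
        for u in range(12)
    ]

    dataset = MyDataset(
        DummyAudioProcessor(),
        meta_data,
        voice_len=1.6,
        num_speakers_in_batch=2,
        num_utter_per_speaker=3,
        verbose=True,
    )
    loader = DataLoader(
        dataset,
        batch_size=dataset.num_speakers_in_batch,
        collate_fn=dataset.collate_fn,
    )
    feats, labels = next(iter(loader))
    # feats: [num_speakers_in_batch * num_utter_per_speaker, frames, n_mels]
    print(feats.shape, labels)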