add: save wavs instead feats to storage.

This is done in order to mitigate staleness when caching and loading from data storage
This commit is contained in:
mueller 2020-09-17 14:14:30 +02:00
parent 1511076fde
commit e36a3067e4
3 changed files with 14 additions and 11 deletions

View File

@ -111,7 +111,7 @@ def train(model, criterion, optimizer, scheduler, ap, global_step):
"lr": current_lr,
"grad_norm": grad_norm,
"step_time": step_time,
"loader_time": loader_time
"avg_loader_time": avg_loader_time
}
tb_logger.tb_train_epoch_stats(global_step, train_stats)
figures = {

View File

@ -110,7 +110,7 @@ class MyDataset(Dataset):
"""
Sample all M utterances for the given speaker.
"""
feats = []
wavs = []
labels = []
for _ in range(self.num_utter_per_speaker):
# TODO:dummy but works
@ -126,11 +126,9 @@ class MyDataset(Dataset):
break
self.speaker_to_utters[speaker].remove(utter)
offset = random.randint(0, wav.shape[0] - self.seq_len)
mel = self.ap.melspectrogram(wav[offset : offset + self.seq_len])
feats.append(torch.FloatTensor(mel))
wavs.append(wav)
labels.append(speaker)
return feats, labels
return wavs, labels
def __getitem__(self, idx):
speaker, _ = self.__sample_speaker()
@ -142,15 +140,21 @@ class MyDataset(Dataset):
for speaker in batch:
if random.random() < self.sample_from_storage_p and self.storage.full():
# sample from storage (if full), ignoring the speaker
feats_, labels_ = random.choice(self.storage.queue)
wavs_, labels_ = random.choice(self.storage.queue)
else:
# don't sample from storage, but from HDD
feats_, labels_ = self.__sample_speaker_utterances(speaker)
wavs_, labels_ = self.__sample_speaker_utterances(speaker)
# if storage is full, remove an item
if self.storage.full():
_ = self.storage.get_nowait()
# put the newly loaded item into storage
self.storage.put_nowait((feats_, labels_))
self.storage.put_nowait((wavs_, labels_))
# get a random subset of each of the wavs and convert to MFCC.
offsets_ = [random.randint(0, wav.shape[0] - self.seq_len) for wav in wavs_]
mels_ = [self.ap.melspectrogram(wavs_[i][offsets_[i]: offsets_[i] + self.seq_len]) for i in range(len(wavs_))]
feats_ = [torch.FloatTensor(mel) for mel in mels_]
labels.append(labels_)
feats.extend(feats_)
feats = torch.stack(feats)

View File

@ -17,10 +17,9 @@ def load_meta_data(datasets):
root_path = dataset['path']
meta_file_train = dataset['meta_file_train']
meta_file_val = dataset['meta_file_val']
print(f" | > Preprocessing {name}")
preprocessor = get_preprocessor_by_name(name)
meta_data_train = preprocessor(root_path, meta_file_train)
print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
if meta_file_val is None:
meta_data_eval, meta_data_train = split_dataset(meta_data_train)
else: