Add argument to disable storage

Edresson Casanova 2022-03-04 07:57:37 -03:00
parent 984b6d9fd1
commit 8ba3385747
4 changed files with 18 additions and 9 deletions

View File

@@ -83,7 +83,7 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
 acc_avg = 0
 for key in class_acc_dict:
     acc = sum(class_acc_dict[key])/len(class_acc_dict[key])
-    print("Class", key, "ACC:", acc)
+    print("Class", key, "Accuracy:", acc)
     acc_avg += acc

-print("Average Acc:", acc_avg/len(class_acc_dict))
+print("Average Accuracy:", acc_avg/len(class_acc_dict))

View File

@@ -41,11 +41,12 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
         voice_len=c.voice_len,
         num_utter_per_class=c.num_utter_per_class,
         num_classes_in_batch=c.num_classes_in_batch,
+        use_storage=c.use_storage,
         skip_classes=c.skip_classes,
         storage_size=c.storage["storage_size"],
         sample_from_storage_p=c.storage["sample_from_storage_p"],
         verbose=verbose,
-        augmentation_config=c.audio_augmentation,
+        augmentation_config=c.audio_augmentation if not is_val else None,
         use_torch_spec=c.model_params.get("use_torch_spec", False),
     )
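A hedged usage sketch of the two behavior changes in this hunk, assuming a loaded encoder config c with the fields shown above and that setup_loader returns a DataLoader: the in-memory storage can now be disabled from the config, and augmentation is only passed to the training loader.

# Assumed usage; setup_loader reads the global config c as in the hunk above.
c.use_storage = False                                        # new flag: dataset skips the Storage cache entirely
train_loader = setup_loader(ap, is_val=False, verbose=True)  # gets c.audio_augmentation
eval_loader = setup_loader(ap, is_val=True)                  # gets augmentation_config=None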

View File

@@ -27,6 +27,7 @@ class BaseEncoderConfig(BaseTrainingConfig):
     audio_augmentation: Dict = field(default_factory=lambda: {})
+    use_storage: bool = False
     storage: Dict = field(
         default_factory=lambda: {
             "sample_from_storage_p": 0.66,  # the probability with which we'll sample from the DataSet in-memory storage
View File

@@ -14,6 +14,7 @@ class EncoderDataset(Dataset):
         meta_data,
         voice_len=1.6,
         num_classes_in_batch=64,
+        use_storage=False,
         storage_size=1,
         sample_from_storage_p=0.5,
         num_utter_per_class=10,
@@ -36,16 +37,21 @@ class EncoderDataset(Dataset):
         self.num_classes_in_batch = num_classes_in_batch
         self.num_utter_per_class = num_utter_per_class
         self.skip_classes = skip_classes
+        self.use_storage = use_storage
         self.ap = ap
         self.verbose = verbose
         self.use_torch_spec = use_torch_spec
         self.__parse_items()
         storage_max_size = storage_size * num_classes_in_batch
-        self.storage = Storage(
-            maxsize=storage_max_size, storage_batchs=storage_size, num_classes_in_batch=num_classes_in_batch
-        )
-        self.sample_from_storage_p = float(sample_from_storage_p)
+        if self.use_storage:
+            self.storage = Storage(
+                maxsize=storage_max_size, storage_batchs=storage_size, num_classes_in_batch=num_classes_in_batch
+            )
+            self.sample_from_storage_p = float(sample_from_storage_p)
+        else:
+            self.storage = None
+            self.sample_from_storage_p = None

         classes_aux = list(self.classes)
         classes_aux.sort()
@@ -163,7 +169,8 @@ class EncoderDataset(Dataset):
             # don't sample from storage, but from HDD
             wavs_, labels_ = self.__sample_class_utterances(class_name)
             # put the newly loaded item into storage
-            self.storage.append((wavs_, labels_))
+            if self.use_storage:
+                self.storage.append((wavs_, labels_))
         return wavs_, labels_

     def collate_fn(self, batch):
@@ -189,7 +196,7 @@ class EncoderDataset(Dataset):
             class_id = self.classname_to_classid[class_name]
             classes_id_in_batch.add(class_id)

-            if random.random() < self.sample_from_storage_p and self.storage.full():
+            if self.use_storage and random.random() < self.sample_from_storage_p and self.storage.full():
                 # sample from storage (if full)
                 wavs_, labels_ = self.storage.get_random_sample_fast()
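Taken together, the guards in this file reduce to one pattern: with probability sample_from_storage_p draw a cached batch from storage once it is full, otherwise load from disk and, only when storage is enabled, cache the fresh item for later draws. A simplified, self-contained sketch (names mirror the diff, but the class itself is hypothetical):

import random

class StorageAwareSampler:
    def __init__(self, storage, sample_from_storage_p, use_storage, load_from_disk):
        self.use_storage = use_storage
        self.storage = storage                        # None when use_storage is False
        self.sample_from_storage_p = sample_from_storage_p
        self.load_from_disk = load_from_disk          # callable: class_name -> (wavs, labels)

    def sample(self, class_name):
        if self.use_storage and random.random() < self.sample_from_storage_p and self.storage.full():
            # sample from storage (only meaningful once it is full)
            return self.storage.get_random_sample_fast()
        # don't sample from storage, but from disk
        wavs_, labels_ = self.load_from_disk(class_name)
        if self.use_storage:
            # put the newly loaded item into storage for later draws
            self.storage.append((wavs_, labels_))
        return wavs_, labels_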