mirror of https://github.com/coqui-ai/TTS.git
Add argument to disable storage
This commit is contained in:
parent
984b6d9fd1
commit
8ba3385747
|
@ -83,7 +83,7 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
|
|||
acc_avg = 0
|
||||
for key in class_acc_dict:
|
||||
acc = sum(class_acc_dict[key])/len(class_acc_dict[key])
|
||||
print("Class", key, "ACC:", acc)
|
||||
print("Class", key, "Accuracy:", acc)
|
||||
acc_avg += acc
|
||||
|
||||
print("Average Acc:", acc_avg/len(class_acc_dict))
|
||||
print("Average Accuracy:", acc_avg/len(class_acc_dict))
|
|
@ -41,11 +41,12 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
|
|||
voice_len=c.voice_len,
|
||||
num_utter_per_class=c.num_utter_per_class,
|
||||
num_classes_in_batch=c.num_classes_in_batch,
|
||||
use_storage=c.use_storage,
|
||||
skip_classes=c.skip_classes,
|
||||
storage_size=c.storage["storage_size"],
|
||||
sample_from_storage_p=c.storage["sample_from_storage_p"],
|
||||
verbose=verbose,
|
||||
augmentation_config=c.audio_augmentation,
|
||||
augmentation_config=c.audio_augmentation if not is_val else None,
|
||||
use_torch_spec=c.model_params.get("use_torch_spec", False),
|
||||
)
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@ class BaseEncoderConfig(BaseTrainingConfig):
|
|||
|
||||
audio_augmentation: Dict = field(default_factory=lambda: {})
|
||||
|
||||
use_storage: bool = False
|
||||
storage: Dict = field(
|
||||
default_factory=lambda: {
|
||||
"sample_from_storage_p": 0.66, # the probability with which we'll sample from the DataSet in-memory storage
|
||||
|
|
|
@ -14,6 +14,7 @@ class EncoderDataset(Dataset):
|
|||
meta_data,
|
||||
voice_len=1.6,
|
||||
num_classes_in_batch=64,
|
||||
use_storage=False,
|
||||
storage_size=1,
|
||||
sample_from_storage_p=0.5,
|
||||
num_utter_per_class=10,
|
||||
|
@ -36,16 +37,21 @@ class EncoderDataset(Dataset):
|
|||
self.num_classes_in_batch = num_classes_in_batch
|
||||
self.num_utter_per_class = num_utter_per_class
|
||||
self.skip_classes = skip_classes
|
||||
self.use_storage = use_storage
|
||||
self.ap = ap
|
||||
self.verbose = verbose
|
||||
self.use_torch_spec = use_torch_spec
|
||||
self.__parse_items()
|
||||
|
||||
storage_max_size = storage_size * num_classes_in_batch
|
||||
self.storage = Storage(
|
||||
maxsize=storage_max_size, storage_batchs=storage_size, num_classes_in_batch=num_classes_in_batch
|
||||
)
|
||||
self.sample_from_storage_p = float(sample_from_storage_p)
|
||||
if self.use_storage:
|
||||
self.storage = Storage(
|
||||
maxsize=storage_max_size, storage_batchs=storage_size, num_classes_in_batch=num_classes_in_batch
|
||||
)
|
||||
self.sample_from_storage_p = float(sample_from_storage_p)
|
||||
else:
|
||||
self.storage = None
|
||||
self.sample_from_storage_p = None
|
||||
|
||||
classes_aux = list(self.classes)
|
||||
classes_aux.sort()
|
||||
|
@ -163,7 +169,8 @@ class EncoderDataset(Dataset):
|
|||
# don't sample from storage, but from HDD
|
||||
wavs_, labels_ = self.__sample_class_utterances(class_name)
|
||||
# put the newly loaded item into storage
|
||||
self.storage.append((wavs_, labels_))
|
||||
if self.use_storage:
|
||||
self.storage.append((wavs_, labels_))
|
||||
return wavs_, labels_
|
||||
|
||||
def collate_fn(self, batch):
|
||||
|
@ -189,7 +196,7 @@ class EncoderDataset(Dataset):
|
|||
class_id = self.classname_to_classid[class_name]
|
||||
classes_id_in_batch.add(class_id)
|
||||
|
||||
if random.random() < self.sample_from_storage_p and self.storage.full():
|
||||
if self.use_storage and random.random() < self.sample_from_storage_p and self.storage.full():
|
||||
# sample from storage (if full)
|
||||
wavs_, labels_ = self.storage.get_random_sample_fast()
|
||||
|
||||
|
|
Loading…
Reference in New Issue