mirror of https://github.com/coqui-ai/TTS.git
Add argumnet to disable storage
This commit is contained in:
parent
984b6d9fd1
commit
8ba3385747
|
@ -83,7 +83,7 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
|
||||||
acc_avg = 0
|
acc_avg = 0
|
||||||
for key in class_acc_dict:
|
for key in class_acc_dict:
|
||||||
acc = sum(class_acc_dict[key])/len(class_acc_dict[key])
|
acc = sum(class_acc_dict[key])/len(class_acc_dict[key])
|
||||||
print("Class", key, "ACC:", acc)
|
print("Class", key, "Accuracy:", acc)
|
||||||
acc_avg += acc
|
acc_avg += acc
|
||||||
|
|
||||||
print("Average Acc:", acc_avg/len(class_acc_dict))
|
print("Average Accuracy:", acc_avg/len(class_acc_dict))
|
|
@ -41,11 +41,12 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
|
||||||
voice_len=c.voice_len,
|
voice_len=c.voice_len,
|
||||||
num_utter_per_class=c.num_utter_per_class,
|
num_utter_per_class=c.num_utter_per_class,
|
||||||
num_classes_in_batch=c.num_classes_in_batch,
|
num_classes_in_batch=c.num_classes_in_batch,
|
||||||
|
use_storage=c.use_storage,
|
||||||
skip_classes=c.skip_classes,
|
skip_classes=c.skip_classes,
|
||||||
storage_size=c.storage["storage_size"],
|
storage_size=c.storage["storage_size"],
|
||||||
sample_from_storage_p=c.storage["sample_from_storage_p"],
|
sample_from_storage_p=c.storage["sample_from_storage_p"],
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
augmentation_config=c.audio_augmentation,
|
augmentation_config=c.audio_augmentation if not is_val else None,
|
||||||
use_torch_spec=c.model_params.get("use_torch_spec", False),
|
use_torch_spec=c.model_params.get("use_torch_spec", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,7 @@ class BaseEncoderConfig(BaseTrainingConfig):
|
||||||
|
|
||||||
audio_augmentation: Dict = field(default_factory=lambda: {})
|
audio_augmentation: Dict = field(default_factory=lambda: {})
|
||||||
|
|
||||||
|
use_storage: bool = False
|
||||||
storage: Dict = field(
|
storage: Dict = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"sample_from_storage_p": 0.66, # the probability with which we'll sample from the DataSet in-memory storage
|
"sample_from_storage_p": 0.66, # the probability with which we'll sample from the DataSet in-memory storage
|
||||||
|
|
|
@ -14,6 +14,7 @@ class EncoderDataset(Dataset):
|
||||||
meta_data,
|
meta_data,
|
||||||
voice_len=1.6,
|
voice_len=1.6,
|
||||||
num_classes_in_batch=64,
|
num_classes_in_batch=64,
|
||||||
|
use_storage=False,
|
||||||
storage_size=1,
|
storage_size=1,
|
||||||
sample_from_storage_p=0.5,
|
sample_from_storage_p=0.5,
|
||||||
num_utter_per_class=10,
|
num_utter_per_class=10,
|
||||||
|
@ -36,16 +37,21 @@ class EncoderDataset(Dataset):
|
||||||
self.num_classes_in_batch = num_classes_in_batch
|
self.num_classes_in_batch = num_classes_in_batch
|
||||||
self.num_utter_per_class = num_utter_per_class
|
self.num_utter_per_class = num_utter_per_class
|
||||||
self.skip_classes = skip_classes
|
self.skip_classes = skip_classes
|
||||||
|
self.use_storage = use_storage
|
||||||
self.ap = ap
|
self.ap = ap
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.use_torch_spec = use_torch_spec
|
self.use_torch_spec = use_torch_spec
|
||||||
self.__parse_items()
|
self.__parse_items()
|
||||||
|
|
||||||
storage_max_size = storage_size * num_classes_in_batch
|
storage_max_size = storage_size * num_classes_in_batch
|
||||||
|
if self.use_storage:
|
||||||
self.storage = Storage(
|
self.storage = Storage(
|
||||||
maxsize=storage_max_size, storage_batchs=storage_size, num_classes_in_batch=num_classes_in_batch
|
maxsize=storage_max_size, storage_batchs=storage_size, num_classes_in_batch=num_classes_in_batch
|
||||||
)
|
)
|
||||||
self.sample_from_storage_p = float(sample_from_storage_p)
|
self.sample_from_storage_p = float(sample_from_storage_p)
|
||||||
|
else:
|
||||||
|
self.storage = None
|
||||||
|
self.sample_from_storage_p = None
|
||||||
|
|
||||||
classes_aux = list(self.classes)
|
classes_aux = list(self.classes)
|
||||||
classes_aux.sort()
|
classes_aux.sort()
|
||||||
|
@ -163,6 +169,7 @@ class EncoderDataset(Dataset):
|
||||||
# don't sample from storage, but from HDD
|
# don't sample from storage, but from HDD
|
||||||
wavs_, labels_ = self.__sample_class_utterances(class_name)
|
wavs_, labels_ = self.__sample_class_utterances(class_name)
|
||||||
# put the newly loaded item into storage
|
# put the newly loaded item into storage
|
||||||
|
if self.use_storage:
|
||||||
self.storage.append((wavs_, labels_))
|
self.storage.append((wavs_, labels_))
|
||||||
return wavs_, labels_
|
return wavs_, labels_
|
||||||
|
|
||||||
|
@ -189,7 +196,7 @@ class EncoderDataset(Dataset):
|
||||||
class_id = self.classname_to_classid[class_name]
|
class_id = self.classname_to_classid[class_name]
|
||||||
classes_id_in_batch.add(class_id)
|
classes_id_in_batch.add(class_id)
|
||||||
|
|
||||||
if random.random() < self.sample_from_storage_p and self.storage.full():
|
if self.use_storage and random.random() < self.sample_from_storage_p and self.storage.full():
|
||||||
# sample from storage (if full)
|
# sample from storage (if full)
|
||||||
wavs_, labels_ = self.storage.get_random_sample_fast()
|
wavs_, labels_ = self.storage.get_random_sample_fast()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue