From 8ba33857473b8b352f6477e4280e526d571a6a03 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 4 Mar 2022 07:57:37 -0300 Subject: [PATCH] Add argumnet to disable storage --- TTS/bin/eval_encoder.py | 4 ++-- TTS/bin/train_encoder.py | 3 ++- TTS/encoder/configs/base_encoder_config.py | 1 + TTS/encoder/dataset.py | 19 +++++++++++++------ 4 files changed, 18 insertions(+), 9 deletions(-) diff --git a/TTS/bin/eval_encoder.py b/TTS/bin/eval_encoder.py index 8acc8ffc..0b1af9f2 100644 --- a/TTS/bin/eval_encoder.py +++ b/TTS/bin/eval_encoder.py @@ -83,7 +83,7 @@ for idx, wav_file in enumerate(tqdm(wav_files)): acc_avg = 0 for key in class_acc_dict: acc = sum(class_acc_dict[key])/len(class_acc_dict[key]) - print("Class", key, "ACC:", acc) + print("Class", key, "Accuracy:", acc) acc_avg += acc -print("Average Acc:", acc_avg/len(class_acc_dict)) +print("Average Accuracy:", acc_avg/len(class_acc_dict)) \ No newline at end of file diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 181ea1e0..c65474db 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -41,11 +41,12 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False voice_len=c.voice_len, num_utter_per_class=c.num_utter_per_class, num_classes_in_batch=c.num_classes_in_batch, + use_storage=c.use_storage, skip_classes=c.skip_classes, storage_size=c.storage["storage_size"], sample_from_storage_p=c.storage["sample_from_storage_p"], verbose=verbose, - augmentation_config=c.audio_augmentation, + augmentation_config=c.audio_augmentation if not is_val else None, use_torch_spec=c.model_params.get("use_torch_spec", False), ) diff --git a/TTS/encoder/configs/base_encoder_config.py b/TTS/encoder/configs/base_encoder_config.py index 5005a47f..838f9300 100644 --- a/TTS/encoder/configs/base_encoder_config.py +++ b/TTS/encoder/configs/base_encoder_config.py @@ -27,6 +27,7 @@ class BaseEncoderConfig(BaseTrainingConfig): audio_augmentation: Dict = field(default_factory=lambda: {}) + use_storage: bool = False storage: Dict = field( default_factory=lambda: { "sample_from_storage_p": 0.66, # the probability with which we'll sample from the DataSet in-memory storage diff --git a/TTS/encoder/dataset.py b/TTS/encoder/dataset.py index 2c777a6a..474aa0c2 100644 --- a/TTS/encoder/dataset.py +++ b/TTS/encoder/dataset.py @@ -14,6 +14,7 @@ class EncoderDataset(Dataset): meta_data, voice_len=1.6, num_classes_in_batch=64, + use_storage=False, storage_size=1, sample_from_storage_p=0.5, num_utter_per_class=10, @@ -36,16 +37,21 @@ class EncoderDataset(Dataset): self.num_classes_in_batch = num_classes_in_batch self.num_utter_per_class = num_utter_per_class self.skip_classes = skip_classes + self.use_storage = use_storage self.ap = ap self.verbose = verbose self.use_torch_spec = use_torch_spec self.__parse_items() storage_max_size = storage_size * num_classes_in_batch - self.storage = Storage( - maxsize=storage_max_size, storage_batchs=storage_size, num_classes_in_batch=num_classes_in_batch - ) - self.sample_from_storage_p = float(sample_from_storage_p) + if self.use_storage: + self.storage = Storage( + maxsize=storage_max_size, storage_batchs=storage_size, num_classes_in_batch=num_classes_in_batch + ) + self.sample_from_storage_p = float(sample_from_storage_p) + else: + self.storage = None + self.sample_from_storage_p = None classes_aux = list(self.classes) classes_aux.sort() @@ -163,7 +169,8 @@ class EncoderDataset(Dataset): # don't sample from storage, but from HDD wavs_, labels_ = self.__sample_class_utterances(class_name) # put the newly loaded item into storage - self.storage.append((wavs_, labels_)) + if self.use_storage: + self.storage.append((wavs_, labels_)) return wavs_, labels_ def collate_fn(self, batch): @@ -189,7 +196,7 @@ class EncoderDataset(Dataset): class_id = self.classname_to_classid[class_name] classes_id_in_batch.add(class_id) - if random.random() < self.sample_from_storage_p and self.storage.full(): + if self.use_storage and random.random() < self.sample_from_storage_p and self.storage.full(): # sample from storage (if full) wavs_, labels_ = self.storage.get_random_sample_fast()