From 711a46506ff117e03b909882cd15bcea1a005f74 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Fri, 4 Mar 2022 18:02:08 -0300
Subject: [PATCH] Fix lint checks

---
 TTS/bin/eval_encoder.py            | 22 +++++++-------
 TTS/bin/train_encoder.py           | 48 +++++++++++++-----------------
 TTS/encoder/dataset.py             | 15 ++++------
 TTS/encoder/utils/generic_utils.py | 45 ----------------------------
 TTS/encoder/utils/samplers.py      | 22 +++++++-------
 5 files changed, 49 insertions(+), 103 deletions(-)

diff --git a/TTS/bin/eval_encoder.py b/TTS/bin/eval_encoder.py
index 9a4e0204..78d42049 100644
--- a/TTS/bin/eval_encoder.py
+++ b/TTS/bin/eval_encoder.py
@@ -1,5 +1,4 @@
 import argparse
-import os
 import torch

 from argparse import RawTextHelpFormatter
@@ -45,7 +44,7 @@ speaker_manager = SpeakerManager(

 if speaker_manager.speaker_encoder_config.map_classid_to_classname is not None:
     map_classid_to_classname = speaker_manager.speaker_encoder_config.map_classid_to_classname
-else: 
+else:
     map_classid_to_classname = None

 # compute speaker embeddings
@@ -69,20 +68,19 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
         predicted_label = map_classid_to_classname[str(class_id)]
     else:
         predicted_label = None
-    
+
     if class_name is not None and predicted_label is not None:
-            is_equal = int(class_name == predicted_label)
-            if class_name not in class_acc_dict:
-                class_acc_dict[class_name] = [is_equal]
-            else:
-                class_acc_dict[class_name].append(is_equal)
+        is_equal = int(class_name == predicted_label)
+        if class_name not in class_acc_dict:
+            class_acc_dict[class_name] = [is_equal]
+        else:
+            class_acc_dict[class_name].append(is_equal)
     else:
-        print("Error: class_name or/and predicted_label are None")
-        exit()
+        raise RuntimeError("Error: class_name or/and predicted_label are None")

 acc_avg = 0
-for key in class_acc_dict:
-    acc = sum(class_acc_dict[key])/len(class_acc_dict[key])
+for key, values in class_acc_dict.items():
+    acc = sum(values)/len(values)
     print("Class", key, "Accuracy:", acc)
     acc_avg += acc

diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py
index 6095713a..caf86a42 100644
--- a/TTS/bin/train_encoder.py
+++ b/TTS/bin/train_encoder.py
@@ -12,7 +12,7 @@ from trainer.torch import NoamLR

 from TTS.encoder.dataset import EncoderDataset
 from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
-from TTS.encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model
+from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_speaker_encoder_model
 from TTS.encoder.utils.samplers import PerfectBatchSampler
 from TTS.encoder.utils.training import init_training
 from TTS.encoder.utils.visual import plot_embeddings
@@ -55,14 +55,13 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
         batch_size=num_classes_in_batch*num_utter_per_class,  # total batch size
         num_classes_in_batch=num_classes_in_batch,
         num_gpus=1,
-        shuffle=False if is_val else True,
+        shuffle=not is_val,
         drop_last=True)

     if len(classes) < num_classes_in_batch:
         if is_val:
             raise RuntimeError(f"config.eval_num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Eval dataset) !")
-        else:
-            raise RuntimeError(f"config.num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Train dataset) !")
+        raise RuntimeError(f"config.num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Train dataset) !")

     # set the classes to avoid get wrong class_id when the number of training and eval classes are not equal
     if is_val:
@@ -73,16 +72,14 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
         num_workers=c.num_loader_workers,
         batch_sampler=sampler,
         collate_fn=dataset.collate_fn,
-        )
+    )

     return loader, classes, dataset.get_map_classid_to_classname()


 def evaluation(model, criterion, data_loader, global_step):
     eval_loss = 0
-    for step, data in enumerate(data_loader):
+    for _, data in enumerate(data_loader):
         with torch.no_grad():
-            start_time = time.time()
-
             # setup input data
             inputs, labels = data
@@ -121,7 +118,7 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
     for epoch in range(c.epochs):
         tot_loss = 0
         epoch_time = 0
-        for step, data in enumerate(data_loader):
+        for _, data in enumerate(data_loader):
             start_time = time.time()

             # setup input data
@@ -129,22 +126,19 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
             # agroup samples of each class in the batch. perfect sampler produces [3,2,1,3,2,1] we need [3,3,2,2,1,1]
             labels = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape)
             inputs = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
-            """ # ToDo: move it to a unit test
-            labels_converted = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape)
-            inputs_converted = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
-            idx = 0
-            for j in range(0, c.num_classes_in_batch, 1):
-                for i in range(j, len(labels), c.num_classes_in_batch):
-                    if not torch.all(labels[i].eq(labels_converted[idx])) or not torch.all(inputs[i].eq(inputs_converted[idx])):
-                        print("Invalid")
-                        print(labels)
-                        exit()
-                    idx += 1
-            labels = labels_converted
-            inputs = inputs_converted
-            print(labels)
-            print(inputs.shape)"""
+            # labels_converted = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape)
+            # inputs_converted = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
+            # idx = 0
+            # for j in range(0, c.num_classes_in_batch, 1):
+            #     for i in range(j, len(labels), c.num_classes_in_batch):
+            #         if not torch.all(labels[i].eq(labels_converted[idx])) or not torch.all(inputs[i].eq(inputs_converted[idx])):
+            #             print("Invalid")
+            #             print(labels)
+            #             exit()
+            #         idx += 1
+            # labels = labels_converted
+            # inputs = inputs_converted

             loader_time = time.time() - end_time
             global_step += 1

@@ -212,12 +206,12 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
                 save_checkpoint(model, optimizer, criterion, loss.item(), OUT_PATH, global_step, epoch)

         end_time = time.time()
-        
+
         print("")
         print(
-            " | > Epoch:{} AvgLoss: {:.5f} GradNorm:{:.5f} "
+            ">>> Epoch:{} AvgLoss: {:.5f} GradNorm:{:.5f} "
             "EpochTime:{:.2f} AvGLoaderTime:{:.2f} ".format(
-                epoch, tot_loss/len(data_loader), grad_norm, epoch_time, avg_loader_time, current_lr
+                epoch, tot_loss/len(data_loader), grad_norm, epoch_time, avg_loader_time
             ),
             flush=True,
         )
diff --git a/TTS/encoder/dataset.py b/TTS/encoder/dataset.py
index 72078f7d..ef24daa1 100644
--- a/TTS/encoder/dataset.py
+++ b/TTS/encoder/dataset.py
@@ -1,10 +1,9 @@
 import random

-import numpy as np
 import torch
 from torch.utils.data import Dataset

-from TTS.encoder.utils.generic_utils import AugmentWAV, Storage
+from TTS.encoder.utils.generic_utils import AugmentWAV


 class EncoderDataset(Dataset):
@@ -33,7 +32,7 @@ class EncoderDataset(Dataset):
         self.ap = ap
         self.verbose = verbose
         self.use_torch_spec = use_torch_spec
-        self.__parse_items()
+        self.classes, self.items = self.__parse_items()

         self.classname_to_classid = {key: i for i, key in enumerate(self.classes)}

@@ -78,15 +77,15 @@ class EncoderDataset(Dataset):
             k: v for (k, v) in class_to_utters.items() if len(v) >= self.num_utter_per_class
         }

-        self.classes = list(class_to_utters.keys())
-        self.classes.sort()
+        classes = list(class_to_utters.keys())
+        classes.sort()

         new_items = []
         for item in self.items:
             path_ = item[1]
             class_name = item[2]
             # ignore filtered classes
-            if class_name not in self.classes:
+            if class_name not in classes:
                 continue
             # ignore small audios
             if self.load_wav(path_).shape[0] - self.seq_len <= 0:
@@ -94,9 +93,7 @@ class EncoderDataset(Dataset):

             new_items.append({"wav_file_path": path_, "class_name": class_name})

-        self.items = new_items
-
-
+        return classes, new_items

     def __len__(self):
         return len(self.items)
diff --git a/TTS/encoder/utils/generic_utils.py b/TTS/encoder/utils/generic_utils.py
index c87645dd..17f1c3d9 100644
--- a/TTS/encoder/utils/generic_utils.py
+++ b/TTS/encoder/utils/generic_utils.py
@@ -3,7 +3,6 @@ import glob
 import os
 import random
 import re
-from multiprocessing import Manager

 import numpy as np
 from scipy import signal
@@ -13,50 +12,6 @@ from TTS.encoder.models.resnet import ResNetSpeakerEncoder
 from TTS.utils.io import save_fsspec


-class Storage(object):
-    def __init__(self, maxsize, storage_batchs, num_classes_in_batch, num_threads=8):
-        # use multiprocessing for threading safe
-        self.storage = Manager().list()
-        self.maxsize = maxsize
-        self.num_classes_in_batch = num_classes_in_batch
-        self.num_threads = num_threads
-        self.ignore_last_batch = False
-
-        if storage_batchs >= 3:
-            self.ignore_last_batch = True
-
-        # used for fast random sample
-        self.safe_storage_size = self.maxsize - self.num_threads
-        if self.ignore_last_batch:
-            self.safe_storage_size -= self.num_classes_in_batch
-
-    def __len__(self):
-        return len(self.storage)
-
-    def full(self):
-        return len(self.storage) >= self.maxsize
-
-    def append(self, item):
-        # if storage is full, remove an item
-        if self.full():
-            self.storage.pop(0)
-
-        self.storage.append(item)
-
-    def get_random_sample(self):
-        # safe storage size considering all threads remove one item from storage in same time
-        storage_size = len(self.storage) - self.num_threads
-
-        if self.ignore_last_batch:
-            storage_size -= self.num_classes_in_batch
-
-        return self.storage[random.randint(0, storage_size)]
-
-    def get_random_sample_fast(self):
-        """Call this method only when storage is full"""
-        return self.storage[random.randint(0, self.safe_storage_size)]
-
-
 class AugmentWAV(object):
     def __init__(self, ap, augmentation_config):
diff --git a/TTS/encoder/utils/samplers.py b/TTS/encoder/utils/samplers.py
index e8d2a601..935aa067 100644
--- a/TTS/encoder/utils/samplers.py
+++ b/TTS/encoder/utils/samplers.py
@@ -1,4 +1,3 @@
-import torch
 import random

 from torch.utils.data.sampler import Sampler, SubsetRandomSampler
@@ -12,6 +11,7 @@ class SubsetSampler(Sampler):
     """

     def __init__(self, indices):
+        super().__init__(indices)
         self.indices = indices

     def __iter__(self):
@@ -35,15 +35,17 @@ class PerfectBatchSampler(Sampler):
     """

     def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False):
-
+        super().__init__(dataset_items)
         assert batch_size % (num_classes_in_batch * num_gpus) == 0, (
             'Batch size must be divisible by number of classes times the number of data parallel devices (if enabled).')

         label_indices = {}
-        for idx in range(len(dataset_items)):
-            label = dataset_items[idx]['class_name']
-            if label not in label_indices: label_indices[label] = []
-            label_indices[label].append(idx)
+        for idx, item in enumerate(dataset_items):
+            label = item['class_name']
+            if label not in label_indices.keys():
+                label_indices[label] = [idx]
+            else:
+                label_indices[label].append(idx)

         if shuffle:
             self._samplers = [SubsetRandomSampler(label_indices[key]) for key in classes]
@@ -68,16 +70,16 @@

             while True:
                 b = []
-                for i in range(len(iters)):
+                for i, it in enumerate(iters):
                     if valid_samplers_idx is not None and i not in valid_samplers_idx:
                         continue
-                    it = iters[i]
                     idx = next(it, None)
                     if idx is None:
                         done = True
                         break
                     b.append(idx)
-                if done: break
+                if done:
+                    break
                 batch += b
                 if len(batch) == self._batch_size:
                     yield batch
@@ -97,4 +99,4 @@

     def __len__(self):
         class_batch_size = self._batch_size // self._num_classes_in_batch
-        return min(((len(s) + class_batch_size - 1) // class_batch_size) for s in self._samplers)
\ No newline at end of file
+        return min(((len(s) + class_batch_size - 1) // class_batch_size) for s in self._samplers)
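
Two reviewer notes follow; the sketches are illustrative and not part of the patch.

First, the consistency check that this patch turns into `#` comments in train_encoder.py was originally marked "ToDo: move it to a unit test". A minimal sketch of that test, assuming only torch and the [3,2,1,3,2,1] -> [3,3,2,2,1,1] regrouping described by the comment above it (the test name and hard-coded values are made up for illustration):

    import torch

    def test_regroup_batch_by_class():
        # The perfect sampler yields utterances interleaved by class,
        # e.g. [3, 2, 1, 3, 2, 1]; the loss expects them grouped by
        # class, e.g. [3, 3, 2, 2, 1, 1].
        num_utter_per_class = 2
        num_classes_in_batch = 3
        labels = torch.tensor([3, 2, 1, 3, 2, 1])
        regrouped = torch.transpose(
            labels.view(num_utter_per_class, num_classes_in_batch), 0, 1
        ).reshape(labels.shape)
        assert torch.equal(regrouped, torch.tensor([3, 3, 2, 2, 1, 1]))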
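
Second, the rewritten PerfectBatchSampler loop keeps its original contract: every batch holds batch_size // num_classes_in_batch indices per class, interleaved by class, which is exactly the ordering the trainer regroups. A small usage sketch against the patched API (the toy items and file names are made up):

    from TTS.encoder.utils.samplers import PerfectBatchSampler

    # Toy dataset: 4 utterances for each of 3 classes.
    items = [
        {"wav_file_path": f"{c}_{i}.wav", "class_name": c}
        for c in ("a", "b", "c")
        for i in range(4)
    ]

    sampler = PerfectBatchSampler(
        dataset_items=items,
        classes=["a", "b", "c"],
        batch_size=6,  # must be divisible by num_classes_in_batch * num_gpus
        num_classes_in_batch=3,
        shuffle=False,
    )

    for batch in sampler:
        # Prints ['a', 'b', 'c', 'a', 'b', 'c'] twice: two indices per
        # class per batch, interleaved; train() regroups them afterwards.
        print([items[i]["class_name"] for i in batch])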