Fix lint checks

This commit is contained in:
Edresson Casanova 2022-03-04 18:02:08 -03:00
parent 33fd07a209
commit 711a46506f
5 changed files with 49 additions and 103 deletions

View File

@ -1,5 +1,4 @@
import argparse import argparse
import os
import torch import torch
from argparse import RawTextHelpFormatter from argparse import RawTextHelpFormatter
@ -71,18 +70,17 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
predicted_label = None predicted_label = None
if class_name is not None and predicted_label is not None: if class_name is not None and predicted_label is not None:
is_equal = int(class_name == predicted_label) is_equal = int(class_name == predicted_label)
if class_name not in class_acc_dict: if class_name not in class_acc_dict:
class_acc_dict[class_name] = [is_equal] class_acc_dict[class_name] = [is_equal]
else: else:
class_acc_dict[class_name].append(is_equal) class_acc_dict[class_name].append(is_equal)
else: else:
print("Error: class_name or/and predicted_label are None") raise RuntimeError("Error: class_name or/and predicted_label are None")
exit()
acc_avg = 0 acc_avg = 0
for key in class_acc_dict: for key, values in class_acc_dict.items():
acc = sum(class_acc_dict[key])/len(class_acc_dict[key]) acc = sum(values)/len(values)
print("Class", key, "Accuracy:", acc) print("Class", key, "Accuracy:", acc)
acc_avg += acc acc_avg += acc

View File

@ -12,7 +12,7 @@ from trainer.torch import NoamLR
from TTS.encoder.dataset import EncoderDataset from TTS.encoder.dataset import EncoderDataset
from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_speaker_encoder_model
from TTS.encoder.utils.samplers import PerfectBatchSampler from TTS.encoder.utils.samplers import PerfectBatchSampler
from TTS.encoder.utils.training import init_training from TTS.encoder.utils.training import init_training
from TTS.encoder.utils.visual import plot_embeddings from TTS.encoder.utils.visual import plot_embeddings
@ -55,14 +55,13 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
batch_size=num_classes_in_batch*num_utter_per_class, # total batch size batch_size=num_classes_in_batch*num_utter_per_class, # total batch size
num_classes_in_batch=num_classes_in_batch, num_classes_in_batch=num_classes_in_batch,
num_gpus=1, num_gpus=1,
shuffle=False if is_val else True, shuffle=not is_val,
drop_last=True) drop_last=True)
if len(classes) < num_classes_in_batch: if len(classes) < num_classes_in_batch:
if is_val: if is_val:
raise RuntimeError(f"config.eval_num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Eval dataset) !") raise RuntimeError(f"config.eval_num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Eval dataset) !")
else: raise RuntimeError(f"config.num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Train dataset) !")
raise RuntimeError(f"config.num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Train dataset) !")
# set the classes to avoid get wrong class_id when the number of training and eval classes are not equal # set the classes to avoid get wrong class_id when the number of training and eval classes are not equal
if is_val: if is_val:
@ -79,10 +78,8 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
def evaluation(model, criterion, data_loader, global_step): def evaluation(model, criterion, data_loader, global_step):
eval_loss = 0 eval_loss = 0
for step, data in enumerate(data_loader): for _, data in enumerate(data_loader):
with torch.no_grad(): with torch.no_grad():
start_time = time.time()
# setup input data # setup input data
inputs, labels = data inputs, labels = data
@ -121,7 +118,7 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
for epoch in range(c.epochs): for epoch in range(c.epochs):
tot_loss = 0 tot_loss = 0
epoch_time = 0 epoch_time = 0
for step, data in enumerate(data_loader): for _, data in enumerate(data_loader):
start_time = time.time() start_time = time.time()
# setup input data # setup input data
@ -129,22 +126,19 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
# agroup samples of each class in the batch. perfect sampler produces [3,2,1,3,2,1] we need [3,3,2,2,1,1] # agroup samples of each class in the batch. perfect sampler produces [3,2,1,3,2,1] we need [3,3,2,2,1,1]
labels = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape) labels = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape)
inputs = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape) inputs = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
"""
# ToDo: move it to a unit test # ToDo: move it to a unit test
labels_converted = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape) # labels_converted = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape)
inputs_converted = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape) # inputs_converted = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
idx = 0 # idx = 0
for j in range(0, c.num_classes_in_batch, 1): # for j in range(0, c.num_classes_in_batch, 1):
for i in range(j, len(labels), c.num_classes_in_batch): # for i in range(j, len(labels), c.num_classes_in_batch):
if not torch.all(labels[i].eq(labels_converted[idx])) or not torch.all(inputs[i].eq(inputs_converted[idx])): # if not torch.all(labels[i].eq(labels_converted[idx])) or not torch.all(inputs[i].eq(inputs_converted[idx])):
print("Invalid") # print("Invalid")
print(labels) # print(labels)
exit() # exit()
idx += 1 # idx += 1
labels = labels_converted # labels = labels_converted
inputs = inputs_converted # inputs = inputs_converted
print(labels)
print(inputs.shape)"""
loader_time = time.time() - end_time loader_time = time.time() - end_time
global_step += 1 global_step += 1
@ -215,9 +209,9 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
print("") print("")
print( print(
" | > Epoch:{} AvgLoss: {:.5f} GradNorm:{:.5f} " ">>> Epoch:{} AvgLoss: {:.5f} GradNorm:{:.5f} "
"EpochTime:{:.2f} AvGLoaderTime:{:.2f} ".format( "EpochTime:{:.2f} AvGLoaderTime:{:.2f} ".format(
epoch, tot_loss/len(data_loader), grad_norm, epoch_time, avg_loader_time, current_lr epoch, tot_loss/len(data_loader), grad_norm, epoch_time, avg_loader_time
), ),
flush=True, flush=True,
) )

View File

@ -1,10 +1,9 @@
import random import random
import numpy as np
import torch import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
from TTS.encoder.utils.generic_utils import AugmentWAV, Storage from TTS.encoder.utils.generic_utils import AugmentWAV
class EncoderDataset(Dataset): class EncoderDataset(Dataset):
def __init__( def __init__(
@ -33,7 +32,7 @@ class EncoderDataset(Dataset):
self.ap = ap self.ap = ap
self.verbose = verbose self.verbose = verbose
self.use_torch_spec = use_torch_spec self.use_torch_spec = use_torch_spec
self.__parse_items() self.classes, self.items = self.__parse_items()
self.classname_to_classid = {key: i for i, key in enumerate(self.classes)} self.classname_to_classid = {key: i for i, key in enumerate(self.classes)}
@ -78,15 +77,15 @@ class EncoderDataset(Dataset):
k: v for (k, v) in class_to_utters.items() if len(v) >= self.num_utter_per_class k: v for (k, v) in class_to_utters.items() if len(v) >= self.num_utter_per_class
} }
self.classes = list(class_to_utters.keys()) classes = list(class_to_utters.keys())
self.classes.sort() classes.sort()
new_items = [] new_items = []
for item in self.items: for item in self.items:
path_ = item[1] path_ = item[1]
class_name = item[2] class_name = item[2]
# ignore filtered classes # ignore filtered classes
if class_name not in self.classes: if class_name not in classes:
continue continue
# ignore small audios # ignore small audios
if self.load_wav(path_).shape[0] - self.seq_len <= 0: if self.load_wav(path_).shape[0] - self.seq_len <= 0:
@ -94,9 +93,7 @@ class EncoderDataset(Dataset):
new_items.append({"wav_file_path": path_, "class_name": class_name}) new_items.append({"wav_file_path": path_, "class_name": class_name})
self.items = new_items return classes, new_items
def __len__(self): def __len__(self):
return len(self.items) return len(self.items)

View File

@ -3,7 +3,6 @@ import glob
import os import os
import random import random
import re import re
from multiprocessing import Manager
import numpy as np import numpy as np
from scipy import signal from scipy import signal
@ -13,50 +12,6 @@ from TTS.encoder.models.resnet import ResNetSpeakerEncoder
from TTS.utils.io import save_fsspec from TTS.utils.io import save_fsspec
class Storage(object):
    """Bounded first-in/first-out buffer with random sampling.

    Items are kept in a ``multiprocessing.Manager`` list. Once ``maxsize``
    items are stored, appending a new item evicts the oldest one. Random
    sampling skips a margin at the tail of the buffer (``num_threads`` items,
    plus ``num_classes_in_batch`` items when ``storage_batchs >= 3``).

    Args:
        maxsize: Number of items at which the buffer is considered full.
        storage_batchs: When ``>= 3`` the sampling margin is widened by
            ``num_classes_in_batch`` items.
        num_classes_in_batch: Extra margin used when ``ignore_last_batch``
            is set.
        num_threads: Base size of the sampling margin.
    """

    def __init__(self, maxsize, storage_batchs, num_classes_in_batch, num_threads=8):
        # use multiprocessing for threading safe
        self.storage = Manager().list()
        self.maxsize = maxsize
        self.num_classes_in_batch = num_classes_in_batch
        self.num_threads = num_threads
        self.ignore_last_batch = storage_batchs >= 3

        # Precomputed upper index for get_random_sample_fast, so that method
        # does not need to query the length of the shared list.
        self.safe_storage_size = self.maxsize - self.num_threads
        if self.ignore_last_batch:
            self.safe_storage_size -= self.num_classes_in_batch

    def __len__(self):
        """Return the number of items currently stored."""
        return len(self.storage)

    def full(self):
        """Return True when the buffer holds at least ``maxsize`` items."""
        return len(self.storage) >= self.maxsize

    def append(self, item):
        """Add ``item``, evicting the oldest entry first if the buffer is full."""
        if self.full():
            self.storage.pop(0)
        self.storage.append(item)

    def get_random_sample(self):
        """Return a random stored item, avoiding the tail margin when possible.

        Raises:
            ValueError: If the buffer is empty.
        """
        size = len(self.storage)
        if size == 0:
            raise ValueError("Cannot sample from an empty Storage")

        # safe storage size considering all threads remove one item from
        # storage in same time
        upper = size - self.num_threads
        if self.ignore_last_batch:
            upper -= self.num_classes_in_batch

        # random.randint is inclusive on both ends: clamp so that a buffer
        # smaller than the margin still yields a sample (instead of raising
        # on an empty range) and a zero margin never indexes past the end.
        upper = max(0, min(upper, size - 1))
        return self.storage[random.randint(0, upper)]

    def get_random_sample_fast(self):
        """Call this method only when storage is full"""
        # Same clamping as get_random_sample, but against maxsize so the
        # length of the shared list is not queried.
        upper = max(0, min(self.safe_storage_size, self.maxsize - 1))
        return self.storage[random.randint(0, upper)]
class AugmentWAV(object): class AugmentWAV(object):
def __init__(self, ap, augmentation_config): def __init__(self, ap, augmentation_config):

View File

@ -1,4 +1,3 @@
import torch
import random import random
from torch.utils.data.sampler import Sampler, SubsetRandomSampler from torch.utils.data.sampler import Sampler, SubsetRandomSampler
@ -12,6 +11,7 @@ class SubsetSampler(Sampler):
""" """
def __init__(self, indices): def __init__(self, indices):
super().__init__(indices)
self.indices = indices self.indices = indices
def __iter__(self): def __iter__(self):
@ -35,15 +35,17 @@ class PerfectBatchSampler(Sampler):
""" """
def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False): def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False):
super().__init__(dataset_items)
assert batch_size % (num_classes_in_batch * num_gpus) == 0, ( assert batch_size % (num_classes_in_batch * num_gpus) == 0, (
'Batch size must be divisible by number of classes times the number of data parallel devices (if enabled).') 'Batch size must be divisible by number of classes times the number of data parallel devices (if enabled).')
label_indices = {} label_indices = {}
for idx in range(len(dataset_items)): for idx, item in enumerate(dataset_items):
label = dataset_items[idx]['class_name'] label = item['class_name']
if label not in label_indices: label_indices[label] = [] if label not in label_indices.keys():
label_indices[label].append(idx) label_indices[label] = [idx]
else:
label_indices[label].append(idx)
if shuffle: if shuffle:
self._samplers = [SubsetRandomSampler(label_indices[key]) for key in classes] self._samplers = [SubsetRandomSampler(label_indices[key]) for key in classes]
@ -68,16 +70,16 @@ class PerfectBatchSampler(Sampler):
while True: while True:
b = [] b = []
for i in range(len(iters)): for i, it in enumerate(iters):
if valid_samplers_idx is not None and i not in valid_samplers_idx: if valid_samplers_idx is not None and i not in valid_samplers_idx:
continue continue
it = iters[i]
idx = next(it, None) idx = next(it, None)
if idx is None: if idx is None:
done = True done = True
break break
b.append(idx) b.append(idx)
if done: break if done:
break
batch += b batch += b
if len(batch) == self._batch_size: if len(batch) == self._batch_size:
yield batch yield batch