Add evaluation during encoder training

This commit is contained in:
Edresson Casanova 2022-03-04 17:30:59 -03:00
parent 0e372e0b9b
commit 33fd07a209
5 changed files with 203 additions and 121 deletions

View File

@ -33,148 +33,218 @@ print(" > Number of GPUs: ", num_gpus)
def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False): def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False):
num_utter_per_class = c.num_utter_per_class if not is_val else c.eval_num_utter_per_class
num_classes_in_batch = c.num_classes_in_batch if not is_val else c.eval_num_classes_in_batch
dataset = EncoderDataset(
ap,
meta_data_eval if is_val else meta_data_train,
voice_len=c.voice_len,
num_utter_per_class=num_utter_per_class,
num_classes_in_batch=num_classes_in_batch,
verbose=verbose,
augmentation_config=c.audio_augmentation if not is_val else None,
use_torch_spec=c.model_params.get("use_torch_spec", False),
)
# get classes list
classes = dataset.get_class_list()
sampler = PerfectBatchSampler(
dataset.items,
classes,
batch_size=num_classes_in_batch*num_utter_per_class, # total batch size
num_classes_in_batch=num_classes_in_batch,
num_gpus=1,
shuffle=False if is_val else True,
drop_last=True)
if len(classes) < num_classes_in_batch:
if is_val:
raise RuntimeError(f"config.eval_num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Eval dataset) !")
else:
raise RuntimeError(f"config.num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Train dataset) !")
# set the classes to avoid getting a wrong class_id when the numbers of training and eval classes are not equal
if is_val: if is_val:
loader = None dataset.set_classes(train_classes)
else:
dataset = EncoderDataset(
ap,
meta_data_eval if is_val else meta_data_train,
voice_len=c.voice_len,
num_utter_per_class=c.num_utter_per_class,
num_classes_in_batch=c.num_classes_in_batch,
verbose=verbose,
augmentation_config=c.audio_augmentation if not is_val else None,
use_torch_spec=c.model_params.get("use_torch_spec", False),
)
sampler = PerfectBatchSampler( loader = DataLoader(
dataset.items, dataset,
dataset.get_class_list(), num_workers=c.num_loader_workers,
batch_size=c.num_classes_in_batch*c.num_utter_per_class, # total batch size batch_sampler=sampler,
num_classes_in_batch=c.num_classes_in_batch, collate_fn=dataset.collate_fn,
num_gpus=1, )
shuffle=False if is_val else True,
drop_last=True)
loader = DataLoader( return loader, classes, dataset.get_map_classid_to_classname()
dataset,
num_workers=c.num_loader_workers,
batch_sampler=sampler,
collate_fn=dataset.collate_fn,
)
return loader, dataset.get_num_classes(), dataset.get_map_classid_to_classname() def evaluation(model, criterion, data_loader, global_step):
eval_loss = 0
for step, data in enumerate(data_loader):
with torch.no_grad():
start_time = time.time()
# setup input data
inputs, labels = data
def train(model, optimizer, scheduler, criterion, data_loader, global_step): # group samples of each class in the batch. perfect sampler produces [3,2,1,3,2,1] we need [3,3,2,2,1,1]
labels = torch.transpose(labels.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch), 0, 1).reshape(labels.shape)
inputs = torch.transpose(inputs.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
# dispatch data to GPU
if use_cuda:
inputs = inputs.cuda(non_blocking=True)
labels = labels.cuda(non_blocking=True)
# forward pass model
outputs = model(inputs)
# loss computation
loss = criterion(outputs.view(c.eval_num_classes_in_batch, outputs.shape[0] // c.eval_num_classes_in_batch, -1), labels)
eval_loss += loss.item()
eval_avg_loss = eval_loss/len(data_loader)
# save stats
dashboard_logger.eval_stats(global_step, {"loss": eval_avg_loss})
# plot the last batch in the evaluation
figures = {
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
}
dashboard_logger.eval_figures(global_step, figures)
return eval_avg_loss
def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, global_step):
model.train() model.train()
epoch_time = 0
best_loss = float("inf") best_loss = float("inf")
avg_loss = 0
avg_loss_all = 0
avg_loader_time = 0 avg_loader_time = 0
end_time = time.time() end_time = time.time()
print(len(data_loader)) for epoch in range(c.epochs):
for _, data in enumerate(data_loader): tot_loss = 0
start_time = time.time() epoch_time = 0
for step, data in enumerate(data_loader):
start_time = time.time()
# setup input data # setup input data
inputs, labels = data inputs, labels = data
# group samples of each class in the batch. perfect sampler produces [3,2,1,3,2,1] we need [3,3,2,2,1,1] # group samples of each class in the batch. perfect sampler produces [3,2,1,3,2,1] we need [3,3,2,2,1,1]
labels = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape) labels = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape)
inputs = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape) inputs = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
""" """
labels_converted = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape) # ToDo: move it to a unit test
inputs_converted = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape) labels_converted = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape)
idx = 0 inputs_converted = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
for j in range(0, c.num_classes_in_batch, 1): idx = 0
for i in range(j, len(labels), c.num_classes_in_batch): for j in range(0, c.num_classes_in_batch, 1):
if not torch.all(labels[i].eq(labels_converted[idx])) or not torch.all(inputs[i].eq(inputs_converted[idx])): for i in range(j, len(labels), c.num_classes_in_batch):
print("Invalid") if not torch.all(labels[i].eq(labels_converted[idx])) or not torch.all(inputs[i].eq(inputs_converted[idx])):
print(labels) print("Invalid")
exit() print(labels)
idx += 1 exit()
labels = labels_converted idx += 1
inputs = inputs_converted labels = labels_converted
print(labels) inputs = inputs_converted
print(inputs.shape)""" print(labels)
print(inputs.shape)"""
loader_time = time.time() - end_time loader_time = time.time() - end_time
global_step += 1 global_step += 1
# setup lr # setup lr
if c.lr_decay: if c.lr_decay:
scheduler.step() scheduler.step()
optimizer.zero_grad() optimizer.zero_grad()
# dispatch data to GPU # dispatch data to GPU
if use_cuda: if use_cuda:
inputs = inputs.cuda(non_blocking=True) inputs = inputs.cuda(non_blocking=True)
labels = labels.cuda(non_blocking=True) labels = labels.cuda(non_blocking=True)
# forward pass model # forward pass model
outputs = model(inputs) outputs = model(inputs)
# loss computation # loss computation
loss = criterion(outputs.view(c.num_classes_in_batch, outputs.shape[0] // c.num_classes_in_batch, -1), labels) loss = criterion(outputs.view(c.num_classes_in_batch, outputs.shape[0] // c.num_classes_in_batch, -1), labels)
loss.backward() loss.backward()
grad_norm, _ = check_update(model, c.grad_clip) grad_norm, _ = check_update(model, c.grad_clip)
optimizer.step() optimizer.step()
step_time = time.time() - start_time step_time = time.time() - start_time
epoch_time += step_time epoch_time += step_time
# Averaged Loss and Averaged Loader Time # accumulate the total epoch loss
avg_loss = 0.01 * loss.item() + 0.99 * avg_loss if avg_loss != 0 else loss.item() tot_loss += loss.item()
num_loader_workers = c.num_loader_workers if c.num_loader_workers > 0 else 1
avg_loader_time = (
1 / num_loader_workers * loader_time + (num_loader_workers - 1) / num_loader_workers * avg_loader_time
if avg_loader_time != 0
else loader_time
)
current_lr = optimizer.param_groups[0]["lr"]
if global_step % c.steps_plot_stats == 0: # Averaged Loader Time
# Plot Training Epoch Stats num_loader_workers = c.num_loader_workers if c.num_loader_workers > 0 else 1
train_stats = { avg_loader_time = (
"loss": avg_loss, 1 / num_loader_workers * loader_time + (num_loader_workers - 1) / num_loader_workers * avg_loader_time
"lr": current_lr, if avg_loader_time != 0
"grad_norm": grad_norm, else loader_time
"step_time": step_time,
"avg_loader_time": avg_loader_time,
}
dashboard_logger.train_epoch_stats(global_step, train_stats)
figures = {
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
}
dashboard_logger.train_figures(global_step, figures)
if global_step % c.print_step == 0:
print(
" | > Step:{} Loss:{:.5f} AvgLoss:{:.5f} GradNorm:{:.5f} "
"StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format(
global_step, loss.item(), avg_loss, grad_norm, step_time, loader_time, avg_loader_time, current_lr
),
flush=True,
) )
avg_loss_all += avg_loss current_lr = optimizer.param_groups[0]["lr"]
if global_step >= c.max_train_step or global_step % c.save_step == 0: if global_step % c.steps_plot_stats == 0:
# save best model only # Plot Training Epoch Stats
best_loss = save_best_model(model, optimizer, criterion, avg_loss, best_loss, OUT_PATH, global_step) train_stats = {
avg_loss_all = 0 "loss": loss.item(),
if global_step >= c.max_train_step: "lr": current_lr,
break "grad_norm": grad_norm,
"step_time": step_time,
"avg_loader_time": avg_loader_time,
}
dashboard_logger.train_epoch_stats(global_step, train_stats)
figures = {
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
}
dashboard_logger.train_figures(global_step, figures)
end_time = time.time() if global_step % c.print_step == 0:
print(
" | > Step:{} Loss:{:.5f} GradNorm:{:.5f} "
"StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format(
global_step, loss.item(), grad_norm, step_time, loader_time, avg_loader_time, current_lr
),
flush=True,
)
return avg_loss, global_step if global_step % c.save_step == 0:
# save model
save_checkpoint(model, optimizer, criterion, loss.item(), OUT_PATH, global_step, epoch)
end_time = time.time()
print("")
print(
" | > Epoch:{} AvgLoss: {:.5f} GradNorm:{:.5f} "
"EpochTime:{:.2f} AvGLoaderTime:{:.2f} ".format(
epoch, tot_loss/len(data_loader), grad_norm, epoch_time, avg_loader_time, current_lr
),
flush=True,
)
# evaluation
if c.run_eval:
model.eval()
eval_loss = evaluation(model, criterion, eval_data_loader, global_step)
print("\n\n")
print("--> EVAL PERFORMANCE")
print(
" | > Epoch:{} AvgLoss: {:.5f} ".format(
epoch, eval_loss
),
flush=True,
)
# save the best checkpoint
best_loss = save_best_model(model, optimizer, criterion, eval_loss, best_loss, OUT_PATH, global_step, epoch)
model.train()
return best_loss, global_step
def main(args): # pylint: disable=redefined-outer-name def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=global-variable-undefined # pylint: disable=global-variable-undefined
global meta_data_train global meta_data_train
global meta_data_eval global meta_data_eval
global train_classes
ap = AudioProcessor(**c.audio) ap = AudioProcessor(**c.audio)
model = setup_speaker_encoder_model(c) model = setup_speaker_encoder_model(c)
@ -184,8 +254,12 @@ def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True) meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True)
train_data_loader, num_classes, map_classid_to_classname = setup_loader(ap, is_val=False, verbose=True) train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False, verbose=True)
# eval_data_loader, _, _ = setup_loader(ap, is_val=True, verbose=True) if c.run_eval:
eval_data_loader, _, _ = setup_loader(ap, is_val=True, verbose=True)
else:
eval_data_loader = None
num_classes = len(train_classes)
if c.loss == "ge2e": if c.loss == "ge2e":
criterion = GE2ELoss(loss_method="softmax") criterion = GE2ELoss(loss_method="softmax")
@ -235,7 +309,7 @@ def main(args): # pylint: disable=redefined-outer-name
criterion.cuda() criterion.cuda()
global_step = args.restore_step global_step = args.restore_step
_, global_step = train(model, optimizer, scheduler, criterion, train_data_loader, global_step) _, global_step = train(model, optimizer, scheduler, criterion, train_data_loader, eval_data_loader, global_step)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -39,15 +39,18 @@ class BaseEncoderConfig(BaseTrainingConfig):
# logging params # logging params
tb_model_param_stats: bool = False tb_model_param_stats: bool = False
steps_plot_stats: int = 10 steps_plot_stats: int = 10
checkpoint: bool = True epochs: int = 10000
save_step: int = 1000 save_step: int = 1000
print_step: int = 20 print_step: int = 20
run_eval: bool = False
# data loader # data loader
num_classes_in_batch: int = MISSING num_classes_in_batch: int = MISSING
num_utter_per_class: int = MISSING num_utter_per_class: int = MISSING
eval_num_classes_in_batch: int = MISSING
eval_num_utter_per_class: int = MISSING
num_loader_workers: int = MISSING num_loader_workers: int = MISSING
skip_classes: bool = False
voice_len: float = 1.6 voice_len: float = 1.6
def check_values(self): def check_values(self):

View File

@ -104,7 +104,11 @@ class EncoderDataset(Dataset):
return len(self.classes) return len(self.classes)
def get_class_list(self): def get_class_list(self):
return list(self.classes) return self.classes
def set_classes(self, classes):
self.classes = classes
self.classname_to_classid = {key: i for i, key in enumerate(self.classes)}
def get_map_classid_to_classname(self): def get_map_classid_to_classname(self):
return dict((c_id, c_n) for c_n, c_id in self.classname_to_classid.items()) return dict((c_id, c_n) for c_n, c_id in self.classname_to_classid.items())

View File

@ -209,7 +209,7 @@ def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_s
save_fsspec(state, checkpoint_path) save_fsspec(state, checkpoint_path)
def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step): def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step, epoch):
if model_loss < best_loss: if model_loss < best_loss:
new_state_dict = model.state_dict() new_state_dict = model.state_dict()
state = { state = {
@ -217,6 +217,7 @@ def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path
"optimizer": optimizer.state_dict(), "optimizer": optimizer.state_dict(),
"criterion": criterion.state_dict(), "criterion": criterion.state_dict(),
"step": current_step, "step": current_step,
"epoch": epoch,
"loss": model_loss, "loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"), "date": datetime.date.today().strftime("%B %d, %Y"),
} }

View File

@ -36,7 +36,7 @@ class PerfectBatchSampler(Sampler):
def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False): def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False):
assert batch_size % (len(classes) * num_gpus) == 0, ( assert batch_size % (num_classes_in_batch * num_gpus) == 0, (
'Batch size must be divisible by number of classes times the number of data parallel devices (if enabled).') 'Batch size must be divisible by number of classes times the number of data parallel devices (if enabled).')
label_indices = {} label_indices = {}