Add evaluation during encoder training

This commit is contained in:
Edresson Casanova 2022-03-04 17:30:59 -03:00
parent 0e372e0b9b
commit 33fd07a209
5 changed files with 203 additions and 121 deletions

View File

@ -33,29 +33,41 @@ print(" > Number of GPUs: ", num_gpus)
def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False): def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False):
if is_val: num_utter_per_class = c.num_utter_per_class if not is_val else c.eval_num_utter_per_class
loader = None num_classes_in_batch = c.num_classes_in_batch if not is_val else c.eval_num_classes_in_batch
else:
dataset = EncoderDataset( dataset = EncoderDataset(
ap, ap,
meta_data_eval if is_val else meta_data_train, meta_data_eval if is_val else meta_data_train,
voice_len=c.voice_len, voice_len=c.voice_len,
num_utter_per_class=c.num_utter_per_class, num_utter_per_class=num_utter_per_class,
num_classes_in_batch=c.num_classes_in_batch, num_classes_in_batch=num_classes_in_batch,
verbose=verbose, verbose=verbose,
augmentation_config=c.audio_augmentation if not is_val else None, augmentation_config=c.audio_augmentation if not is_val else None,
use_torch_spec=c.model_params.get("use_torch_spec", False), use_torch_spec=c.model_params.get("use_torch_spec", False),
) )
# get classes list
classes = dataset.get_class_list()
sampler = PerfectBatchSampler( sampler = PerfectBatchSampler(
dataset.items, dataset.items,
dataset.get_class_list(), classes,
batch_size=c.num_classes_in_batch*c.num_utter_per_class, # total batch size batch_size=num_classes_in_batch*num_utter_per_class, # total batch size
num_classes_in_batch=c.num_classes_in_batch, num_classes_in_batch=num_classes_in_batch,
num_gpus=1, num_gpus=1,
shuffle=False if is_val else True, shuffle=False if is_val else True,
drop_last=True) drop_last=True)
if len(classes) < num_classes_in_batch:
if is_val:
raise RuntimeError(f"config.eval_num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Eval dataset) !")
else:
raise RuntimeError(f"config.num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Train dataset) !")
# set the classes to avoid get wrong class_id when the number of training and eval classes are not equal
if is_val:
dataset.set_classes(train_classes)
loader = DataLoader( loader = DataLoader(
dataset, dataset,
num_workers=c.num_loader_workers, num_workers=c.num_loader_workers,
@ -63,19 +75,53 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
collate_fn=dataset.collate_fn, collate_fn=dataset.collate_fn,
) )
return loader, dataset.get_num_classes(), dataset.get_map_classid_to_classname() return loader, classes, dataset.get_map_classid_to_classname()
def evaluation(model, criterion, data_loader, global_step):
    """Compute the mean loss over the evaluation loader and log it.

    Every batch from ``data_loader`` is passed through ``model`` with
    gradients disabled and scored with ``criterion``. The mean loss and a
    UMAP plot of the last batch's outputs are sent to ``dashboard_logger``.
    This function does not switch the model between train and eval mode;
    the caller is expected to do that.

    Args:
        model: Callable that maps a batch of inputs to outputs.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``,
            with ``outputs`` reshaped to
            ``(eval_num_classes_in_batch, utterances_per_class, -1)``.
        data_loader: Sized iterable yielding ``(inputs, labels)`` pairs. Each
            batch must hold ``eval_num_utter_per_class *
            eval_num_classes_in_batch`` samples, interleaved by class.
        global_step: Step index attached to the logged stats and figures.

    Returns:
        The summed batch losses divided by ``len(data_loader)``.

    Raises:
        RuntimeError: If ``data_loader`` yields no batches.
    """
    num_classes = c.eval_num_classes_in_batch
    num_utter = c.eval_num_utter_per_class

    eval_loss = 0
    outputs = None
    # A single no_grad context covers the whole pass instead of being
    # re-entered for every batch.
    with torch.no_grad():
        for inputs, labels in data_loader:
            # Group the samples of each class in the batch: the perfect
            # sampler produces [3,2,1,3,2,1] but the loss needs [3,3,2,2,1,1].
            labels = torch.transpose(labels.view(num_utter, num_classes), 0, 1).reshape(labels.shape)
            inputs = torch.transpose(inputs.view(num_utter, num_classes, -1), 0, 1).reshape(inputs.shape)

            # dispatch data to GPU
            if use_cuda:
                inputs = inputs.cuda(non_blocking=True)
                labels = labels.cuda(non_blocking=True)

            # forward pass model
            outputs = model(inputs)

            # loss computation
            loss = criterion(outputs.view(num_classes, outputs.shape[0] // num_classes, -1), labels)
            eval_loss += loss.item()

    if outputs is None:
        # Without this guard an empty loader fails later with an opaque
        # ZeroDivisionError (and `outputs` would be unusable for the plot).
        raise RuntimeError(
            "The evaluation data loader produced no batches; check that the eval set holds at least "
            "eval_num_classes_in_batch * eval_num_utter_per_class samples."
        )

    eval_avg_loss = eval_loss / len(data_loader)

    # save stats
    dashboard_logger.eval_stats(global_step, {"loss": eval_avg_loss})

    # Plot the last evaluation batch. The batch was laid out with the eval
    # batch geometry, so the plot must be labelled with the eval class count,
    # not the training one.
    figures = {
        "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), num_classes),
    }
    dashboard_logger.eval_figures(global_step, figures)
    return eval_avg_loss
def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, global_step):
model.train() model.train()
epoch_time = 0
best_loss = float("inf") best_loss = float("inf")
avg_loss = 0
avg_loss_all = 0
avg_loader_time = 0 avg_loader_time = 0
end_time = time.time() end_time = time.time()
print(len(data_loader)) for epoch in range(c.epochs):
for _, data in enumerate(data_loader): tot_loss = 0
epoch_time = 0
for step, data in enumerate(data_loader):
start_time = time.time() start_time = time.time()
# setup input data # setup input data
@ -84,6 +130,7 @@ def train(model, optimizer, scheduler, criterion, data_loader, global_step):
labels = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape) labels = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape)
inputs = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape) inputs = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
""" """
# ToDo: move it to a unit test
labels_converted = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape) labels_converted = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape)
inputs_converted = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape) inputs_converted = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
idx = 0 idx = 0
@ -124,8 +171,10 @@ def train(model, optimizer, scheduler, criterion, data_loader, global_step):
step_time = time.time() - start_time step_time = time.time() - start_time
epoch_time += step_time epoch_time += step_time
# Averaged Loss and Averaged Loader Time # acumulate the total epoch loss
avg_loss = 0.01 * loss.item() + 0.99 * avg_loss if avg_loss != 0 else loss.item() tot_loss += loss.item()
# Averaged Loader Time
num_loader_workers = c.num_loader_workers if c.num_loader_workers > 0 else 1 num_loader_workers = c.num_loader_workers if c.num_loader_workers > 0 else 1
avg_loader_time = ( avg_loader_time = (
1 / num_loader_workers * loader_time + (num_loader_workers - 1) / num_loader_workers * avg_loader_time 1 / num_loader_workers * loader_time + (num_loader_workers - 1) / num_loader_workers * avg_loader_time
@ -137,7 +186,7 @@ def train(model, optimizer, scheduler, criterion, data_loader, global_step):
if global_step % c.steps_plot_stats == 0: if global_step % c.steps_plot_stats == 0:
# Plot Training Epoch Stats # Plot Training Epoch Stats
train_stats = { train_stats = {
"loss": avg_loss, "loss": loss.item(),
"lr": current_lr, "lr": current_lr,
"grad_norm": grad_norm, "grad_norm": grad_norm,
"step_time": step_time, "step_time": step_time,
@ -151,30 +200,51 @@ def train(model, optimizer, scheduler, criterion, data_loader, global_step):
if global_step % c.print_step == 0: if global_step % c.print_step == 0:
print( print(
" | > Step:{} Loss:{:.5f} AvgLoss:{:.5f} GradNorm:{:.5f} " " | > Step:{} Loss:{:.5f} GradNorm:{:.5f} "
"StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format( "StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format(
global_step, loss.item(), avg_loss, grad_norm, step_time, loader_time, avg_loader_time, current_lr global_step, loss.item(), grad_norm, step_time, loader_time, avg_loader_time, current_lr
), ),
flush=True, flush=True,
) )
avg_loss_all += avg_loss
if global_step >= c.max_train_step or global_step % c.save_step == 0: if global_step % c.save_step == 0:
# save best model only # save model
best_loss = save_best_model(model, optimizer, criterion, avg_loss, best_loss, OUT_PATH, global_step) save_checkpoint(model, optimizer, criterion, loss.item(), OUT_PATH, global_step, epoch)
avg_loss_all = 0
if global_step >= c.max_train_step:
break
end_time = time.time() end_time = time.time()
return avg_loss, global_step print("")
print(
" | > Epoch:{} AvgLoss: {:.5f} GradNorm:{:.5f} "
"EpochTime:{:.2f} AvGLoaderTime:{:.2f} ".format(
epoch, tot_loss/len(data_loader), grad_norm, epoch_time, avg_loader_time, current_lr
),
flush=True,
)
# evaluation
if c.run_eval:
model.eval()
eval_loss = evaluation(model, criterion, eval_data_loader, global_step)
print("\n\n")
print("--> EVAL PERFORMANCE")
print(
" | > Epoch:{} AvgLoss: {:.5f} ".format(
epoch, eval_loss
),
flush=True,
)
# save the best checkpoint
best_loss = save_best_model(model, optimizer, criterion, eval_loss, best_loss, OUT_PATH, global_step, epoch)
model.train()
return best_loss, global_step
def main(args): # pylint: disable=redefined-outer-name def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=global-variable-undefined # pylint: disable=global-variable-undefined
global meta_data_train global meta_data_train
global meta_data_eval global meta_data_eval
global train_classes
ap = AudioProcessor(**c.audio) ap = AudioProcessor(**c.audio)
model = setup_speaker_encoder_model(c) model = setup_speaker_encoder_model(c)
@ -184,8 +254,12 @@ def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True) meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True)
train_data_loader, num_classes, map_classid_to_classname = setup_loader(ap, is_val=False, verbose=True) train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False, verbose=True)
# eval_data_loader, _, _ = setup_loader(ap, is_val=True, verbose=True) if c.run_eval:
eval_data_loader, _, _ = setup_loader(ap, is_val=True, verbose=True)
else:
eval_data_loader = None
num_classes = len(train_classes)
if c.loss == "ge2e": if c.loss == "ge2e":
criterion = GE2ELoss(loss_method="softmax") criterion = GE2ELoss(loss_method="softmax")
@ -235,7 +309,7 @@ def main(args): # pylint: disable=redefined-outer-name
criterion.cuda() criterion.cuda()
global_step = args.restore_step global_step = args.restore_step
_, global_step = train(model, optimizer, scheduler, criterion, train_data_loader, global_step) _, global_step = train(model, optimizer, scheduler, criterion, train_data_loader, eval_data_loader, global_step)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -39,15 +39,18 @@ class BaseEncoderConfig(BaseTrainingConfig):
# logging params # logging params
tb_model_param_stats: bool = False tb_model_param_stats: bool = False
steps_plot_stats: int = 10 steps_plot_stats: int = 10
checkpoint: bool = True epochs: int = 10000
save_step: int = 1000 save_step: int = 1000
print_step: int = 20 print_step: int = 20
run_eval: bool = False
# data loader # data loader
num_classes_in_batch: int = MISSING num_classes_in_batch: int = MISSING
num_utter_per_class: int = MISSING num_utter_per_class: int = MISSING
eval_num_classes_in_batch: int = MISSING
eval_num_utter_per_class: int = MISSING
num_loader_workers: int = MISSING num_loader_workers: int = MISSING
skip_classes: bool = False
voice_len: float = 1.6 voice_len: float = 1.6
def check_values(self): def check_values(self):

View File

@ -104,7 +104,11 @@ class EncoderDataset(Dataset):
return len(self.classes) return len(self.classes)
def get_class_list(self): def get_class_list(self):
return list(self.classes) return self.classes
def set_classes(self, classes):
    """Replace the class collection and rebuild the name-to-id lookup.

    Ids are assigned by enumeration order of ``classes``, starting at 0.
    The passed object is stored as-is (no copy is made).
    """
    self.classes = classes
    name_to_id = {}
    for class_id, class_name in enumerate(self.classes):
        name_to_id[class_name] = class_id
    self.classname_to_classid = name_to_id
def get_map_classid_to_classname(self): def get_map_classid_to_classname(self):
return dict((c_id, c_n) for c_n, c_id in self.classname_to_classid.items()) return dict((c_id, c_n) for c_n, c_id in self.classname_to_classid.items())

View File

@ -209,7 +209,7 @@ def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_s
save_fsspec(state, checkpoint_path) save_fsspec(state, checkpoint_path)
def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step): def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step, epoch):
if model_loss < best_loss: if model_loss < best_loss:
new_state_dict = model.state_dict() new_state_dict = model.state_dict()
state = { state = {
@ -217,6 +217,7 @@ def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path
"optimizer": optimizer.state_dict(), "optimizer": optimizer.state_dict(),
"criterion": criterion.state_dict(), "criterion": criterion.state_dict(),
"step": current_step, "step": current_step,
"epoch": epoch,
"loss": model_loss, "loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"), "date": datetime.date.today().strftime("%B %d, %Y"),
} }

View File

@ -36,7 +36,7 @@ class PerfectBatchSampler(Sampler):
def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False): def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False):
assert batch_size % (len(classes) * num_gpus) == 0, ( assert batch_size % (num_classes_in_batch * num_gpus) == 0, (
'Batch size must be divisible by number of classes times the number of data parallel devices (if enabled).') 'Batch size must be divisible by number of classes times the number of data parallel devices (if enabled).')
label_indices = {} label_indices = {}