From 4160e8fca1aa7498e2821c3c250ca21e00637848 Mon Sep 17 00:00:00 2001
From: Eren
Date: Thu, 5 Jul 2018 17:30:42 +0200
Subject: [PATCH] Change logging for the new cluster system

---
 train.py | 56 +++++++++++++++++++++++++++++++++++++-------------------
 1 file changed, 37 insertions(+), 19 deletions(-)

diff --git a/train.py b/train.py
index f570b7fe..34f96829 100644
--- a/train.py
+++ b/train.py
@@ -78,7 +78,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
         mel_input = data[3]
         mel_lengths = data[4]
         stop_targets = data[5]
-
+
         # set stop targets view, we predict a single stop token per r frames prediction
         stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1)
         stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
@@ -89,10 +89,10 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
         # setup lr
         current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
         current_lr_st = lr_decay(c.lr, current_step, c.warmup_steps)
-
+
         for params_group in optimizer.param_groups:
             params_group['lr'] = current_lr
-
+
         for params_group in optimizer_st.param_groups:
             params_group['lr'] = current_lr_st

@@ -106,7 +106,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
             mel_lengths = mel_lengths.cuda()
             linear_input = linear_input.cuda()
             stop_targets = stop_targets.cuda()
-
+
         # forward pass
         mel_output, linear_output, alignments, stop_tokens =\
             model.forward(text_input, mel_input)
@@ -128,13 +128,13 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
             print(" | > Iteration skipped!!")
             continue
         optimizer.step()
-
+
         # backpass and check the grad norm for stop loss
         stop_loss.backward()
         grad_norm_st, skip_flag = check_update(model.module.decoder.stopnet, 0.5, 100)
         if skip_flag:
             optimizer_st.zero_grad()
-            print(" | > Iteration skipped fro stopnet!!")
+            print(" | | > Iteration skipped for stopnet!!")
             continue
         optimizer_st.step()

@@ -142,12 +142,23 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
         epoch_time += step_time

         # update
-        progbar.update(num_iter+1, values=[('total_loss', loss.item()),
-                                           ('linear_loss', linear_loss.item()),
-                                           ('mel_loss', mel_loss.item()),
-                                           ('stop_loss', stop_loss.item()),
-                                           ('grad_norm', grad_norm.item()),
-                                           ('grad_norm_st', grad_norm_st.item())])
+        # progbar.update(num_iter+1, values=[('total_loss', loss.item()),
+        #                                    ('linear_loss', linear_loss.item()),
+        #                                    ('mel_loss', mel_loss.item()),
+        #                                    ('stop_loss', stop_loss.item()),
+        #                                    ('grad_norm', grad_norm.item()),
+        #                                    ('grad_norm_st', grad_norm_st.item())])
+
+        if current_step % c.print_step == 0:
+            print(" | | > TotalLoss: {:.5f}\t LinearLoss: {:.5f}\t MelLoss: \
+                  {:.5f}\t StopLoss: {:.5f}\t GradNorm: {:.5f}\t \
+                  GradNormST: {:.5f}".format(loss.item(),
+                                             linear_loss.item(),
+                                             mel_loss.item(),
+                                             stop_loss.item(),
+                                             grad_norm.item(),
+                                             grad_norm_st.item()))
+
         avg_linear_loss += linear_loss.item()
         avg_mel_loss += mel_loss.item()
         avg_stop_loss += stop_loss.item()
@@ -219,7 +230,7 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step):
     avg_mel_loss = 0
     avg_stop_loss = 0
     print(" | > Validation")
-    progbar = Progbar(len(data_loader.dataset) / c.batch_size)
+    # progbar = Progbar(len(data_loader.dataset) / c.batch_size)
     n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
     with torch.no_grad():
         for num_iter, data in enumerate(data_loader):
@@ -232,7 +243,7 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step):
             mel_input = data[3]
             mel_lengths = data[4]
             stop_targets = data[5]
-
+
             # set stop targets view, we predict a single stop token per r frames prediction
             stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1)
             stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
@@ -262,10 +273,16 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step):
             epoch_time += step_time

             # update
-            progbar.update(num_iter+1, values=[('total_loss', loss.item()),
-                                               ('linear_loss', linear_loss.item()),
-                                               ('mel_loss', mel_loss.item()),
-                                               ('stop_loss', stop_loss.item())])
+            # progbar.update(num_iter+1, values=[('total_loss', loss.item()),
+            #                                    ('linear_loss', linear_loss.item()),
+            #                                    ('mel_loss', mel_loss.item()),
+            #                                    ('stop_loss', stop_loss.item())])
+            if current_step % c.print_step == 0:
+                print(" | | > TotalLoss: {:.5f}\t LinearLoss: {:.5f}\t MelLoss: \
+                      {:.5f}\t StopLoss: {:.5f}\t".format(loss.item(),
+                                                          linear_loss.item(),
+                                                          mel_loss.item(),
+                                                          stop_loss.item()))

             avg_linear_loss += linear_loss.item()
             avg_mel_loss += mel_loss.item()
@@ -366,7 +383,7 @@ def main(args):
     optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)

     criterion = L1LossMasked()
-    criterion_st = nn.BCELoss()
+    criterion_st = nn.BCELoss()

     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
@@ -405,6 +422,7 @@ def main(args):
         train_loss, current_step = train(
            model, criterion, criterion_st, train_loader, optimizer, optimizer_st, epoch)
         val_loss = evaluate(model, criterion, criterion_st, val_loader, current_step)
+        print(" >>> Train Loss: {:.5f}\t Validation Loss: {:.5f}".format(train_loss, val_loss))
         best_loss = save_best_model(model, optimizer, val_loss, best_loss, OUT_PATH,
                                     current_step, epoch)
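
Note: the patch replaces the Progbar progress-bar updates with plain print() calls gated on c.print_step, presumably because carriage-return progress bars render poorly in cluster log files, while one flat line every print_step steps stays grep-friendly. Below is a minimal standalone sketch of that gating pattern, not the project's actual code; print_step and the loss value are illustrative placeholders standing in for c.print_step and loss.item():

    # Sketch: periodic console logging gated on a step counter.
    print_step = 10  # placeholder for c.print_step from the training config

    for current_step in range(1, 101):
        total_loss = 1.0 / current_step  # dummy value standing in for loss.item()
        if current_step % print_step == 0:
            # emits one flat, parseable log line every `print_step` steps
            print(" | | > Step: {}\t TotalLoss: {:.5f}".format(current_step, total_loss))

The same modulo check works unchanged inside both train() and evaluate(), which is why the patch applies it in both loops.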