From dc3f909ddc80e54f70e7c506c0e90fc217dfb185 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Fri, 11 May 2018 08:38:07 -0700
Subject: [PATCH] Separate backward pass for stop-token prediction

---
 train.py | 29 +++++++++++++++++++++++------
 1 file changed, 23 insertions(+), 6 deletions(-)

diff --git a/train.py b/train.py
index 7fa485b6..7ca06ed8 100644
--- a/train.py
+++ b/train.py
@@ -59,7 +59,7 @@
 LOG_DIR = OUT_PATH
 tb = SummaryWriter(LOG_DIR)
 
-def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
+def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, epoch):
     model = model.train()
     epoch_time = 0
     avg_linear_loss = 0
@@ -88,10 +88,15 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
 
         # setup lr
         current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
+
        for params_group in optimizer.param_groups:
             params_group['lr'] = current_lr
+
+        for params_group in optimizer_st.param_groups:
+            params_group['lr'] = current_lr
 
         optimizer.zero_grad()
+        optimizer_st.zero_grad()
 
         # dispatch data to GPU
         if use_cuda:
@@ -112,16 +117,25 @@
             + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                               linear_input[:, :, :n_priority_freq],
                               mel_lengths)
-        loss = mel_loss + linear_loss + stop_loss
+        loss = mel_loss + linear_loss
 
-        # backpass and check the grad norm
-        loss.backward()
+        # backpass and check the grad norm for spec losses
+        loss.backward(retain_graph=True)
         grad_norm, skip_flag = check_update(model, 0.5, 100)
         if skip_flag:
             optimizer.zero_grad()
             print(" | > Iteration skipped!!")
             continue
         optimizer.step()
+
+        # backpass and check the grad norm for stop loss
+        stop_loss.backward()
+        grad_norm_st, skip_flag = check_update(model.module.decoder.stopnet, 0.5, 100)
+        if skip_flag:
+            optimizer_st.zero_grad()
+            print(" | > Iteration skipped for stopnet!!")
+            continue
+        optimizer_st.step()
 
         step_time = time.time() - start_time
         epoch_time += step_time
@@ -131,7 +145,8 @@
                                 ('linear_loss', linear_loss.item()),
                                 ('mel_loss', mel_loss.item()),
                                 ('stop_loss', stop_loss.item()),
-                                ('grad_norm', grad_norm.item())])
+                                ('grad_norm', grad_norm.item()),
+                                ('grad_norm_st', grad_norm_st.item())])
         avg_linear_loss += linear_loss.item()
         avg_mel_loss += mel_loss.item()
         avg_stop_loss += stop_loss.item()
@@ -144,6 +159,7 @@
         tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
                       current_step)
         tb.add_scalar('Params/GradNorm', grad_norm, current_step)
+        tb.add_scalar('Params/GradNormSt', grad_norm_st, current_step)
         tb.add_scalar('Time/StepTime', step_time, current_step)
 
         if current_step % c.save_step == 0:
@@ -345,6 +361,7 @@
                      c.r)
 
     optimizer = optim.Adam(model.parameters(), lr=c.lr)
+    optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)
 
     criterion = L1LossMasked()
     criterion_st = nn.BCELoss()
@@ -378,7 +395,7 @@
     for epoch in range(0, c.epochs):
         train_loss, current_step = train(
-            model, criterion, criterion_st, train_loader, optimizer, epoch)
+            model, criterion, criterion_st, train_loader, optimizer, optimizer_st, epoch)
         val_loss = evaluate(model, criterion, criterion_st, val_loader, current_step)
         best_loss = save_best_model(model, optimizer, val_loss, best_loss, OUT_PATH,
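
Note: below is a minimal, self-contained sketch of the training-step pattern this patch introduces: the spectrogram (mel + linear) losses back-propagate through the whole model and are applied by the main optimizer, while the stop-token loss gets its own backward pass and its own optimizer that holds only the stopnet parameters. ToyTacotron, trunk and spec_head are illustrative stand-ins, not modules from this repo. The patch calls stop_loss.backward() directly; the sketch restricts that backward to the stopnet parameters (inputs=...) so it also runs on recent PyTorch, where the in-place parameter update done by optimizer.step() between the two backward calls would otherwise trip autograd's in-place check.

import torch
import torch.nn as nn
import torch.optim as optim


class ToyTacotron(nn.Module):
    """Stand-in model: a shared trunk, a spectrogram head and a small
    stop-token predictor ("stopnet")."""

    def __init__(self, dim=16):
        super().__init__()
        self.trunk = nn.Linear(dim, dim)
        self.spec_head = nn.Linear(dim, dim)
        self.stopnet = nn.Linear(dim, 1)

    def forward(self, x):
        h = torch.relu(self.trunk(x))
        spec = self.spec_head(h)                            # "spectrogram" output
        stop = torch.sigmoid(self.stopnet(h)).squeeze(-1)   # stop-token probability
        return spec, stop


model = ToyTacotron()
optimizer = optim.Adam(model.parameters(), lr=1e-3)             # whole model
optimizer_st = optim.Adam(model.stopnet.parameters(), lr=1e-3)  # stopnet only
criterion = nn.L1Loss()       # plays the role of L1LossMasked
criterion_st = nn.BCELoss()

x = torch.randn(8, 16)
spec_target = torch.randn(8, 16)
stop_target = torch.randint(0, 2, (8,)).float()

optimizer.zero_grad()
optimizer_st.zero_grad()

spec_out, stop_out = model(x)
spec_loss = criterion(spec_out, spec_target)
stop_loss = criterion_st(stop_out, stop_target)

# 1) Spectrogram loss drives the whole model; retain_graph=True keeps the
#    graph alive so the stop loss can run its own backward pass afterwards.
spec_loss.backward(retain_graph=True)
optimizer.step()

# 2) Stop loss updates only the stopnet through its own optimizer. Restricting
#    the backward to the stopnet parameters is an adaptation for current
#    PyTorch; the patch itself simply calls stop_loss.backward().
stop_loss.backward(inputs=list(model.stopnet.parameters()))
optimizer_st.step()

Stepping the main optimizer before the stop-loss backward is what keeps the spectrogram update free of stop-token gradients; any stop gradients that do reach the rest of the network in the patch's version are simply discarded when gradients are zeroed at the start of the next iteration.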