Separate backward pass for stop-token prediction

Eren Golge 2018-05-11 08:38:07 -07:00
parent 5a5b9263e2
commit dc3f909ddc
1 changed file with 23 additions and 6 deletions


@@ -59,7 +59,7 @@ LOG_DIR = OUT_PATH
 tb = SummaryWriter(LOG_DIR)

-def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
+def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, epoch):
     model = model.train()
     epoch_time = 0
     avg_linear_loss = 0
@@ -88,10 +88,15 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
         # setup lr
         current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
         for params_group in optimizer.param_groups:
             params_group['lr'] = current_lr
+        for params_group in optimizer_st.param_groups:
+            params_group['lr'] = current_lr

         optimizer.zero_grad()
+        optimizer_st.zero_grad()

         # dispatch data to GPU
         if use_cuda:
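
Both optimizers are kept on the same schedule by writing the decayed rate into each one's param groups. A minimal sketch of the pattern, assuming lr_decay implements a Noam-style warmup (the repo's exact formula may differ):

    def lr_decay(init_lr, step, warmup_steps):
        # Assumed Noam-style schedule: linear warmup, then ~1/sqrt(step) decay.
        step = max(step, 1)
        return init_lr * warmup_steps ** 0.5 * min(step * warmup_steps ** -1.5,
                                                   step ** -0.5)

    def sync_lr(optimizers, lr):
        # Write the same learning rate into every param group of every
        # optimizer, as the two explicit loops in the hunk above do.
        for opt in optimizers:
            for group in opt.param_groups:
                group['lr'] = lr
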
@@ -112,16 +117,25 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
             + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                               linear_input[:, :, :n_priority_freq],
                               mel_lengths)
-        loss = mel_loss + linear_loss + stop_loss
+        loss = mel_loss + linear_loss

-        # backpass and check the grad norm
-        loss.backward()
+        # backpass and check the grad norm for spec losses
+        loss.backward(retain_graph=True)
         grad_norm, skip_flag = check_update(model, 0.5, 100)
         if skip_flag:
             optimizer.zero_grad()
             print(" | > Iteration skipped!!")
             continue
         optimizer.step()

+        # backpass and check the grad norm for stop loss
+        stop_loss.backward()
+        grad_norm_st, skip_flag = check_update(model.module.decoder.stopnet, 0.5, 100)
+        if skip_flag:
+            optimizer_st.zero_grad()
+            print(" | > Iteration skipped for stopnet!!")
+            continue
+        optimizer_st.step()

         step_time = time.time() - start_time
         epoch_time += step_time
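
This is the core of the change: the spectrogram losses and the stop-token loss now get separate backward passes, with retain_graph=True keeping the shared autograd graph alive for the second one. A self-contained toy of the pattern (module names are illustrative, not the repo's; unlike the commit, both backward passes here run before the optimizer steps, since newer PyTorch versions reject backpropagation through parameters that were already updated in place):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # A shared trunk feeds a spectrogram head and a stop-token head.
    trunk, spec_head, stop_head = nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 1)
    optimizer = torch.optim.Adam(
        list(trunk.parameters()) + list(spec_head.parameters()), lr=1e-3)
    optimizer_st = torch.optim.Adam(stop_head.parameters(), lr=1e-3)

    x = torch.randn(4, 8)
    spec_target = torch.randn(4, 8)
    stop_target = torch.randint(0, 2, (4, 1)).float()

    optimizer.zero_grad()
    optimizer_st.zero_grad()
    h = trunk(x)
    spec_loss = F.l1_loss(spec_head(h), spec_target)
    stop_loss = F.binary_cross_entropy(torch.sigmoid(stop_head(h)), stop_target)

    spec_loss.backward(retain_graph=True)  # keep the shared graph for stop_loss
    stop_loss.backward()                   # second pass frees the graph
    optimizer.step()                       # updates trunk + spectrogram head
    optimizer_st.step()                    # updates only the stop head

Note that stop_loss.backward() also deposits gradients on the shared layers. Because optimizer_st is scoped to the stop head, they never produce a stop-specific update of the rest of the model; in the commit's ordering (optimizer.step() before stop_loss.backward()) they are simply discarded by the next iteration's zero_grad().
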
@@ -131,7 +145,8 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
                 ('linear_loss', linear_loss.item()),
                 ('mel_loss', mel_loss.item()),
                 ('stop_loss', stop_loss.item()),
-                ('grad_norm', grad_norm.item())])
+                ('grad_norm', grad_norm.item()),
+                ('grad_norm_st', grad_norm_st.item())])

         avg_linear_loss += linear_loss.item()
         avg_mel_loss += mel_loss.item()
         avg_stop_loss += stop_loss.item()
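
The grad_norm and grad_norm_st values logged above come from check_update, which this diff does not show. A hypothetical sketch of what such a helper typically does (the repo's actual implementation may differ):

    import math
    import torch

    def check_update(model, grad_clip, skip_threshold):
        # Clip gradients in place; clip_grad_norm_ returns the pre-clip total
        # norm. Flag the step for skipping when that norm is non-finite or
        # above the threshold.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        skip_flag = (not math.isfinite(float(grad_norm))
                     or float(grad_norm) > skip_threshold)
        return grad_norm, skip_flag

This matches the call sites above, which pass either the whole model or just model.module.decoder.stopnet, a clip value of 0.5, and a skip threshold of 100.
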
@@ -144,6 +159,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
             tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
                           current_step)
             tb.add_scalar('Params/GradNorm', grad_norm, current_step)
+            tb.add_scalar('Params/GradNormSt', grad_norm_st, current_step)
             tb.add_scalar('Time/StepTime', step_time, current_step)

             if current_step % c.save_step == 0:
@@ -345,6 +361,7 @@ def main(args):
                  c.r)
     optimizer = optim.Adam(model.parameters(), lr=c.lr)
+    optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)
     criterion = L1LossMasked()
     criterion_st = nn.BCELoss()
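
The diff never shows where stop_loss itself is formed. A hedged sketch of a typical BCE stop-token loss, with illustrative tensor names and shapes:

    import torch
    import torch.nn as nn

    criterion_st = nn.BCELoss()

    # Illustrative only: the stopnet emits one stop probability per decoder
    # frame; targets are 0 until an utterance's final frame, then 1.
    stop_tokens = torch.sigmoid(torch.randn(2, 50))  # (batch, decoder_steps)
    stop_targets = torch.zeros(2, 50)
    stop_targets[0, 42:] = 1.0  # utterance 0 ends at frame 42
    stop_targets[1, 35:] = 1.0  # utterance 1 ends at frame 35
    stop_loss = criterion_st(stop_tokens, stop_targets)
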
@@ -378,7 +395,7 @@ def main(args):
     for epoch in range(0, c.epochs):
         train_loss, current_step = train(
-            model, criterion, criterion_st, train_loader, optimizer, epoch)
+            model, criterion, criterion_st, train_loader, optimizer, optimizer_st, epoch)
         val_loss = evaluate(model, criterion, criterion_st, val_loader, current_step)
         best_loss = save_best_model(model, optimizer, val_loss,
                                     best_loss, OUT_PATH,