From cc9bfe96af4b13700dde27d5650e053853285470 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Mon, 30 Apr 2018 05:47:14 -0700 Subject: [PATCH] add stop loss --- train.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/train.py b/train.py index 151cbbeb..5805ea5c 100644 --- a/train.py +++ b/train.py @@ -62,11 +62,12 @@ else: print(" > Priority freq. is disabled.") -def train(model, criterion, data_loader, optimizer, epoch): +def train(model, criterion, criterion_st, data_loader, optimizer, epoch): model = model.train() epoch_time = 0 avg_linear_loss = 0 avg_mel_loss = 0 + avg_stop_loss = 0 avg_attn_loss = 0 print(" | > Epoch {}/{}".format(epoch, c.epochs)) @@ -108,18 +109,19 @@ def train(model, criterion, data_loader, optimizer, epoch): mk = mk_decay(c.mk, c.epochs, epoch) # forward pass - mel_output, linear_output, alignments =\ + mel_output, linear_output, alignments, stop_tokens =\ model.forward(text_input, mel_spec) # loss computation mel_loss = criterion(mel_output, mel_spec, mel_lengths) linear_loss = criterion(linear_output, linear_spec, mel_lengths) + stop_loss = criterion_st(stop_tokens, stop_targets) if c.priority_freq: linear_loss = 0.5 * linear_loss\ + 0.5 * criterion(linear_output[:, :, :n_priority_freq], linear_spec[:, :, :n_priority_freq], mel_lengths) - loss = mel_loss + linear_loss + loss = mel_loss + linear_loss + stop_loss if c.mk > 0.0: attention_loss = criterion(alignments, M, mel_lengths) loss += mk * attention_loss @@ -141,12 +143,14 @@ def train(model, criterion, data_loader, optimizer, epoch): progbar_display['total_loss'] = loss.item() progbar_display['linear_loss'] = linear_loss.item() progbar_display['mel_loss'] = mel_loss.item() + progbar_display['stop_loss'] = stop_loss.item() progbar_display['grad_norm'] = grad_norm.item() # update progbar.update(num_iter+1, values=list(progbar_display.items())) avg_linear_loss += linear_loss.item() avg_mel_loss += mel_loss.item() + avg_stop_loss += st_loss.item() # Plot Training Iter Stats tb.add_scalar('TrainIterLoss/TotalLoss', loss.item(), current_step) @@ -193,11 +197,13 @@ def train(model, criterion, data_loader, optimizer, epoch): avg_linear_loss /= (num_iter + 1) avg_mel_loss /= (num_iter + 1) - avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss /= (num_iter + 1) + avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss # Plot Training Epoch Stats tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step) tb.add_scalar('TrainEpochLoss/LinearLoss', avg_linear_loss, current_step) + tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step) tb.add_scalar('TrainEpochLoss/MelLoss', avg_mel_loss, current_step) if c.mk > 0: avg_attn_loss /= (num_iter + 1) @@ -208,11 +214,12 @@ def train(model, criterion, data_loader, optimizer, epoch): return avg_linear_loss, current_step -def evaluate(model, criterion, data_loader, current_step): +def evaluate(model, criterion, criterion_st, data_loader, current_step): model = model.eval() epoch_time = 0 avg_linear_loss = 0 avg_mel_loss = 0 + avg_stop_loss = 0 print("\n | > Validation") progbar = Progbar(len(data_loader.dataset) / c.batch_size) @@ -236,18 +243,19 @@ def evaluate(model, criterion, data_loader, current_step): linear_spec = linear_spec.cuda() # forward pass - mel_output, linear_output, alignments =\ + mel_output, linear_output, alignments, stop_tokens =\ model.forward(text_input, mel_spec) # loss computation mel_loss = criterion(mel_output, mel_spec, mel_lengths) linear_loss = criterion(linear_output, linear_spec, mel_lengths) + stop_loss = criterion_st(stop_tokens, stop_targets) if c.priority_freq: linear_loss = 0.5 * linear_loss\ + 0.5 * criterion(linear_output[:, :, :n_priority_freq], linear_spec[:, :, :n_priority_freq], mel_lengths) - loss = mel_loss + linear_loss + loss = mel_loss + linear_loss + stop_loss step_time = time.time() - start_time epoch_time += step_time @@ -256,11 +264,13 @@ def evaluate(model, criterion, data_loader, current_step): progbar.update(num_iter+1, values=[('total_loss', loss.item()), ('linear_loss', linear_loss.item()), + ('stop_loss', stop_loss.item()), ('mel_loss', mel_loss.item())]) sys.stdout.flush() avg_linear_loss += linear_loss.item() avg_mel_loss += mel_loss.item() + avg_stop_loss += stop_loss.item() # Diagnostic visualizations idx = np.random.randint(mel_spec.shape[0]) @@ -292,12 +302,14 @@ def evaluate(model, criterion, data_loader, current_step): # compute average losses avg_linear_loss /= (num_iter + 1) avg_mel_loss /= (num_iter + 1) - avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss /= (num_iter + 1) + avg_total_loss = avg_mel_loss + avg_linear_loss + stop_loss # Plot Learning Stats tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss, current_step) tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss, current_step) tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step) + tb.add_scalar('ValEpochLoss/Stop_loss', avg_stop_loss, current_step) return avg_linear_loss @@ -355,8 +367,10 @@ def main(args): if use_cuda: criterion = L1LossMasked().cuda() + criterion_st = nn.BCELoss().cuda() else: criterion = L1LossMasked() + criterion_st = nn.BCELoss() if args.restore_path: checkpoint = torch.load(args.restore_path) @@ -392,8 +406,8 @@ def main(args): for epoch in range(0, c.epochs): train_loss, current_step = train( - model, criterion, train_loader, optimizer, epoch) - val_loss = evaluate(model, criterion, val_loader, current_step) + model, criterion, criterion_st, train_loader, optimizer, epoch) + val_loss = evaluate(model, criterion, criterion_st, val_loader, current_step) best_loss = save_best_model(model, optimizer, val_loss, best_loss, OUT_PATH, current_step, epoch)