print avg spec length in training

This commit is contained in:
Eren Golge 2018-12-11 15:08:02 +01:00
parent 0b35e8e949
commit 2ec55e2369
1 changed files with 3 additions and 2 deletions

View File

@ -52,6 +52,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
mel_lengths = data[4]
stop_targets = data[5]
avg_text_length = torch.mean(text_lengths.float())
avg_spec_length = torch.mean(mel_lengths.float())
# set stop targets view, we predict a single stop token per r frames prediction
stop_targets = stop_targets.view(text_input.shape[0],
@ -130,10 +131,10 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
print(
" | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
"MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
"GradNormST:{:.5f} AvgTextLen:{:.1f} StepTime:{:.2f} LR:{:.6f}".format(
"GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} LR:{:.6f}".format(
num_iter, batch_n_iter, current_step, loss.item(),
linear_loss.item(), mel_loss.item(), stop_loss.item(),
grad_norm, grad_norm_st, avg_text_length, step_time, current_lr),
grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time, current_lr),
flush=True)
avg_linear_loss += linear_loss.item()