make attn guiding optional #2

commit f197ab1e28
parent cbd24d9a42
Author: Eren Golge
Date:   2018-04-25 05:45:32 -07:00
2 changed files with 2 additions and 2 deletions


@@ -16,7 +16,7 @@
     "batch_size": 32,
     "eval_batch_size":32,
     "r": 5,
-    "mk": 1,
+    "mk": 0,
     "griffin_lim_iters": 60,
     "power": 1.2,


@@ -126,6 +126,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
         if c.mk > 0.0:
             attention_loss = criterion(alignments, M, mel_lengths_var)
             loss += c.mk * attention_loss
+            avg_attn_loss += attention_loss.data[0]
 
         # backpass and check the grad norm
         loss.backward()
@@ -148,7 +149,6 @@ def train(model, criterion, data_loader, optimizer, epoch):
                                ('grad_norm', grad_norm)])
         avg_linear_loss += linear_loss.data[0]
         avg_mel_loss += mel_loss.data[0]
-        avg_attn_loss += attention_loss.data[0]
 
         # Plot Training Iter Stats
         tb.add_scalar('TrainIterLoss/TotalLoss', loss.data[0], current_step)
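Read together, the two hunks make the guided-attention term fully optional: it is added to the loss and accumulated in `avg_attn_loss` only when `c.mk > 0.0`, and moving the accumulation inside the branch avoids referencing an `attention_loss` that was never computed. A minimal sketch of that control flow; the L1 criterion, tensor shapes, and toy usage are assumptions, not the repo's actual code:

```python
import torch
from types import SimpleNamespace

def add_attention_loss(c, criterion, alignments, M, mel_lengths_var,
                       loss, avg_attn_loss):
    # When c.mk is 0 the branch is skipped, so attention guiding costs nothing;
    # any positive value re-enables it, with c.mk doubling as the term's weight.
    if c.mk > 0.0:
        attention_loss = criterion(alignments, M, mel_lengths_var)
        loss = loss + c.mk * attention_loss
        avg_attn_loss += attention_loss.item()  # .data[0] in the 0.3-era diff
    return loss, avg_attn_loss

# Toy usage; shapes and the unused lengths argument are assumptions.
c = SimpleNamespace(mk=0.0)  # mirrors "mk": 0 from the config change above
criterion = lambda att, mask, lengths: torch.nn.functional.l1_loss(att, mask)
alignments = torch.rand(4, 50, 120)  # (batch, decoder steps, encoder steps)
M = torch.rand(4, 50, 120)           # guiding mask with the same shape
loss, avg_attn = add_attention_loss(c, criterion, alignments, M, None,
                                    torch.tensor(0.0), 0.0)
print(float(loss), avg_attn)  # both 0.0: mk == 0 keeps the term disabled
```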