make attn guiding optional #2

Eren Golge 2018-04-25 05:51:02 -07:00
parent f197ab1e28
commit 3c8ded5a18
1 changed file with 8 additions and 6 deletions


@@ -72,6 +72,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
     print(" | > Epoch {}/{}".format(epoch, c.epochs))
     progbar = Progbar(len(data_loader.dataset) / c.batch_size)
     n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
+    progbar_display = {}
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
@@ -127,6 +128,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
             attention_loss = criterion(alignments, M, mel_lengths_var)
             loss += mk * attention_loss
             avg_attn_loss += attention_loss.data[0]
+            progbar_display['attn_loss'] = attention_loss.data[0]
 
         # backpass and check the grad norm
         loss.backward()
@@ -139,14 +141,14 @@ def train(model, criterion, data_loader, optimizer, epoch):
         step_time = time.time() - start_time
         epoch_time += step_time
+        progbar_display['total_loss'] = loss.data[0]
+        progbar_display['linear_loss'] = linear_loss.data[0]
+        progbar_display['mel_loss'] = mel_loss.data[0]
+        progbar_display['grad_norm'] = grad_norm
         # update
-        progbar.update(num_iter+1, values=[('total_loss', loss.data[0]),
-                                           ('linear_loss',
-                                            linear_loss.data[0]),
-                                           ('mel_loss', mel_loss.data[0]),
-                                           ('attn_loss', attention_loss.data[0]),
-                                           ('grad_norm', grad_norm)])
+        progbar.update(num_iter + 1, values=list(progbar_display.items()))
         avg_linear_loss += linear_loss.data[0]
         avg_mel_loss += mel_loss.data[0]
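
For context, a minimal sketch of the pattern this diff adopts: per-step metrics are collected in a dict, so an optional term such as the guided-attention loss only appears in the progress bar when it is actually computed. The train_step function, the use_attn_guiding flag, and the loss values below are hypothetical stand-ins, and Progbar is assumed to follow the Keras convention update(step, values), where values is a list of (name, value) pairs.

def train_step(num_iter, progbar, use_attn_guiding=False):
    # collect per-step metrics here; optional ones are simply never added
    progbar_display = {}

    # placeholder values standing in for the real forward pass (hypothetical)
    linear_loss, mel_loss = 0.7, 0.3
    loss = linear_loss + mel_loss

    if use_attn_guiding:
        # the guided-attention term is computed and displayed only when enabled
        attention_loss = 0.1
        loss += attention_loss
        progbar_display['attn_loss'] = attention_loss

    progbar_display['total_loss'] = loss
    progbar_display['linear_loss'] = linear_loss
    progbar_display['mel_loss'] = mel_loss

    # a single update call takes a list of (name, value) pairs, i.e.
    # list(d.items()), not [tuple(d.items())], which would pass one
    # malformed entry that Progbar cannot unpack
    progbar.update(num_iter + 1, values=list(progbar_display.items()))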