enforce monotonic attention in forward attention for batches

Eren Golge 2019-05-28 14:28:32 +02:00
parent d905f6e795
commit 0b5a00d29e
1 changed file with 3 additions and 6 deletions


@@ -203,16 +203,13 @@ class Attention(nn.Module):
         alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
                   self.u * prev_alpha) + 1e-8) * alignment
         # force incremental alignment - TODO: make configurable
-        if not self.training and alignment.shape[0] == 1:
+        if not self.training:
             _, n = prev_alpha.max(1)
             val, n2 = alpha.max(1)
             for b in range(alignment.shape[0]):
                 alpha[b, n + 2:] = 0
-                alpha[b, :(
-                    n - 1
-                )] = 0  # ignore all previous states to prevent repetition.
-                alpha[b, (
-                    n - 2)] = 0.01 * val  # smoothing factor for the prev step
+                alpha[b, :(n - 1)] = 0  # ignore all previous states to prevent repetition.
+                alpha[b, (n - 2)] = 0.01 * val  # smoothing factor for the prev step
         # compute attention weights
         self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
         # compute context
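
With the batch-size-1 guard removed, the masking has to hold for every element of a batch at inference time. The snippet below is a standalone sketch of that masking step, not the repository's code: the function name apply_monotonic_mask, the smoothing argument, and the per-sample indexing via prev_pos[b] and peak[b] are assumptions made to keep the example self-contained; boundary steps (prev_pos < 2) are left unhandled, as in the diff above.

import torch

def apply_monotonic_mask(alpha, prev_alpha, smoothing=0.01):
    # Hypothetical helper illustrating the constraint enforced in the diff.
    # alpha, prev_alpha: [B, T_in] attention weights for the current and previous decoder step.
    _, prev_pos = prev_alpha.max(1)              # most attended input index at the previous step
    peak, _ = alpha.max(1)                       # current peak attention value per batch element
    for b in range(alpha.shape[0]):
        n = prev_pos[b]
        alpha[b, n + 2:] = 0                     # forbid jumping more than one input step ahead
        alpha[b, :(n - 1)] = 0                   # forbid moving backwards, preventing repetition
        alpha[b, (n - 2)] = smoothing * peak[b]  # keep a small residue just behind the previous peak
    return alpha / alpha.sum(dim=1).unsqueeze(1) # renormalize to a valid distribution

The sketch indexes with prev_pos[b] and peak[b] so each sample is masked around its own previous peak; the diff keeps the whole n and val tensors inside the loop, a difference that only becomes visible once the batch size exceeds one.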