diff --git a/layers/common_layers.py b/layers/common_layers.py
index b6f72bc1..f7b8e7ed 100644
--- a/layers/common_layers.py
+++ b/layers/common_layers.py
@@ -208,7 +208,7 @@ class Attention(nn.Module):
             _, n = prev_alpha.max(1)
             val, n2 = alpha.max(1)
             for b in range(alignment.shape[0]):
-                alpha[b, n[b] + 2:] = 0
+                alpha[b, n[b] + 3:] = 0
                 alpha[b, :(n[b] - 1)] = 0  # ignore all previous states to prevent repetition.
                 alpha[b, (n[b] - 2)] = 0.01 * val[b]  # smoothing factor for the prev step
         # compute attention weights
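
The hunk widens the forward-attention mask window: with `n[b] + 2:` zeroed, only one position ahead of the previous peak could stay active; with `n[b] + 3:` zeroed, two positions ahead remain. Below is a minimal, self-contained sketch (not part of the patch) that runs the masked loop on a toy tensor so the effect of the new window is visible; the shapes and the peak index are illustrative assumptions, not values from the repository.

```python
import torch

# Toy illustration of the widened forward-attention window (assumed shapes).
alpha = torch.rand(1, 10)      # stand-in for the transition-weighted alignment
prev_alpha = torch.zeros(1, 10)
prev_alpha[0, 4] = 1.0         # assume the previous step attended position 4

_, n = prev_alpha.max(1)
val, n2 = alpha.max(1)
for b in range(alpha.shape[0]):
    alpha[b, n[b] + 3:] = 0               # keep up to two steps ahead of the previous peak
    alpha[b, :(n[b] - 1)] = 0             # ignore all previous states to prevent repetition
    alpha[b, (n[b] - 2)] = 0.01 * val[b]  # smoothing factor for the prev step

print(alpha.nonzero(as_tuple=True)[1])    # indices 2..6 may stay nonzero (2..5 with the old `+ 2`)
```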