Force alignment of forward attention

Eren Golge 2019-05-24 13:18:18 +02:00
parent f08be38b9f
commit 3a4a3e571a
1 changed file with 5 additions and 1 deletion


@@ -219,6 +219,10 @@ class Attention(nn.Module):
         # forward attention
         prev_alpha = F.pad(self.alpha[:, :-1].clone(),
                            (1, 0, 0, 0)).to(inputs.device)
+        # force incremental alignment
+        if not self.training:
+            val, n = prev_alpha.max(1)
+            alignment[:, n + 2:] = 0
         alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
                   self.u * prev_alpha) + 1e-8) * alignment
         self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
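The added block masks the alignment at inference time so attention cannot jump more than two encoder steps past its previous peak, forcing a roughly monotonic, incremental alignment. Note that the committed slice `alignment[:, n + 2:] = 0` uses the peak-index tensor `n` directly as a slice bound, which only resolves when the batch holds a single utterance (the usual inference case). A batched sketch of the same masking idea, where the helper name, the `window` argument, and the shapes are assumptions rather than part of the commit, could look like this:

import torch

def force_incremental_alignment(alignment, prev_alpha, window=2):
    # Sketch only: keep alignment energies up to `window` steps past the
    # previous attention peak and zero everything further ahead, per batch item.
    # alignment, prev_alpha: [B, T_enc]
    _, peak = prev_alpha.max(dim=1)                                # [B] previous peak index
    positions = torch.arange(alignment.size(1), device=alignment.device)
    keep = positions.unsqueeze(0) <= (peak + window).unsqueeze(1)  # [B, T_enc] boolean window
    return alignment * keep.to(alignment.dtype)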
@@ -500,7 +504,7 @@ class Decoder(nn.Module):
             stop_flags[2] = t > inputs.shape[1] * 2
             if all(stop_flags):
                 stop_count += 1
-                if stop_count > 5:
+                if stop_count > 10:
                     break
             elif len(outputs) == self.max_decoder_steps:
                 print(" | > Decoder stopped with 'max_decoder_steps")
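The second hunk only raises the decoder's stopping patience: the combined stop conditions must now hold for more than ten accumulated steps (previously five) before inference breaks out of the loop. A minimal sketch of that loop structure, where `decode_step` and `patience` are illustrative assumptions and not the repository's API:

def run_decoder(decode_step, max_decoder_steps=500, patience=10):
    # Sketch only: decode_step(t) is assumed to return (frame, stop_flags).
    outputs, stop_count = [], 0
    for t in range(max_decoder_steps):
        frame, stop_flags = decode_step(t)
        outputs.append(frame)
        if all(stop_flags):
            stop_count += 1            # stop criteria held again this step
            if stop_count > patience:  # this commit raises the threshold from 5 to 10
                break
    return outputs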