Commenting the attention code

Eren Golge 2018-05-23 06:16:39 -07:00
parent 7acf4eab94
commit d8c460442a
1 changed file with 9 additions and 0 deletions


@@ -70,6 +70,15 @@ class LocationSensitiveAttention(nn.Module):
class AttentionRNN(nn.Module):
    def __init__(self, out_dim, annot_dim, memory_dim, align_model):
        r"""
        General Attention RNN wrapper.

        Args:
            out_dim (int): context vector feature dimension.
            annot_dim (int): annotation vector feature dimension.
            memory_dim (int): memory vector (decoder autoregression) feature dimension.
            align_model (str): 'b' for Bahdanau, 'ls' for Location Sensitive alignment.
        """
        super(AttentionRNN, self).__init__()
        self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
        # pick Bahdanau or location-sensitive attention
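The closing comment refers to a selection branch that follows in the file but is outside this hunk. A minimal sketch of what that branch could look like, assuming BahdanauAttention and LocationSensitiveAttention classes exist in this module with (annot_dim, out_dim, out_dim) constructors (their real signatures are not shown in this diff):

        # Hypothetical continuation of __init__; the attention classes'
        # constructor arguments are assumptions, not shown in this hunk.
        if align_model == 'b':
            self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
        elif align_model == 'ls':
            self.alignment_model = LocationSensitiveAttention(annot_dim, out_dim, out_dim)
        else:
            raise RuntimeError("Unknown alignment model: {}".format(align_model))

Given the documented arguments, a caller would construct the wrapper along these lines (the dimension values here are illustrative only):

    # Hypothetical dimensions for illustration.
    attn_rnn = AttentionRNN(out_dim=256, annot_dim=256, memory_dim=256, align_model='ls')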