diff --git a/layers/attention.py b/layers/attention.py index 9c63a85f..f598e182 100644 --- a/layers/attention.py +++ b/layers/attention.py @@ -13,8 +13,8 @@ class BahdanauAttention(nn.Module): def forward(self, annots, query): """ Shapes: - - query: (batch, 1, dim) or (batch, dim) - annots: (batch, max_time, dim) + - query: (batch, 1, dim) or (batch, dim) """ if query.dim() == 2: # insert time-axis for broadcasting