Update `PositionalEncoding`

Eren Gölge 2021-09-03 13:28:46 +00:00
parent 076d0cb258
commit 29248536c9
1 changed file with 16 additions and 4 deletions


@@ -7,17 +7,23 @@ from torch import nn
 class PositionalEncoding(nn.Module):
     """Sinusoidal positional encoding for non-recurrent neural networks.
     Implementation based on "Attention Is All You Need"
     Args:
         channels (int): embedding size
-        dropout (float): dropout parameter
+        dropout_p (float): dropout rate applied to the output.
+        max_len (int): maximum sequence length.
+        use_scale (bool): whether to use a learnable scaling coefficient.
     """
-    def __init__(self, channels, dropout_p=0.0, max_len=5000):
+    def __init__(self, channels, dropout_p=0.0, max_len=5000, use_scale=False):
         super().__init__()
         if channels % 2 != 0:
             raise ValueError(
                 "Cannot use sin/cos positional encoding with " "odd channels (got channels={:d})".format(channels)
             )
+        self.use_scale = use_scale
+        if use_scale:
+            self.scale = torch.nn.Parameter(torch.ones(1))
         pe = torch.zeros(max_len, channels)
         position = torch.arange(0, max_len).unsqueeze(1)
         div_term = torch.pow(10000, torch.arange(0, channels, 2).float() / channels)
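
For context, `div_term` evaluates to 10000^(2i/channels) for each even channel index 2i, so (assuming the elided remainder of `__init__` follows the "Attention Is All You Need" formulation, which this hunk cuts off before) the buffer being filled is the standard sinusoidal table:

    PE(pos, 2i)   = sin(pos / 10000^(2i/channels))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/channels))

The only functional change in this hunk is the new `use_scale` flag, which registers a single learnable parameter `self.scale` initialized to 1.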
@@ -49,9 +55,15 @@ class PositionalEncoding(nn.Module):
                 pos_enc = self.pe[:, :, : x.size(2)] * mask
             else:
                 pos_enc = self.pe[:, :, : x.size(2)]
-            x = x + pos_enc
+            if self.use_scale:
+                x = x + self.scale * pos_enc
+            else:
+                x = x + pos_enc
         else:
-            x = x + self.pe[:, :, first_idx:last_idx]
+            if self.use_scale:
+                x = x + self.scale * self.pe[:, :, first_idx:last_idx]
+            else:
+                x = x + self.pe[:, :, first_idx:last_idx]
         if hasattr(self, "dropout"):
             x = self.dropout(x)
         return x
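
As a minimal, self-contained sketch (not code from this repository) of what the new scaled branch computes, assuming the `[batch, channels, time]` layout implied by the `self.pe[:, :, : x.size(2)]` indexing and a `[1, channels, max_len]` buffer:

import torch

channels, max_len = 4, 10
# Rebuild the sinusoidal table as in __init__ above; the sin/cos assignment
# itself is assumed from the paper, since the hunk ends before those lines.
pe = torch.zeros(max_len, channels)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.pow(10000, torch.arange(0, channels, 2).float() / channels)
pe[:, 0::2] = torch.sin(position / div_term)
pe[:, 1::2] = torch.cos(position / div_term)
pe = pe.transpose(0, 1).unsqueeze(0)        # assumed buffer layout: [1, channels, max_len]

scale = torch.nn.Parameter(torch.ones(1))   # what use_scale=True registers
x = torch.zeros(2, channels, 6)             # dummy input: [batch, channels, time]
out = x + scale * pe[:, :, : x.size(2)]     # the new scaled branch of forward()
print(out.shape)                            # torch.Size([2, 4, 6])

With `use_scale=False` (the default) no `self.scale` parameter is created and the addition is unscaled, so the module's output and state_dict match the previous behavior.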