mirror of https://github.com/coqui-ai/TTS.git
Update `PositionalEncoding`
This commit is contained in:
parent 076d0cb258
commit 29248536c9
@@ -7,17 +7,23 @@ from torch import nn
 class PositionalEncoding(nn.Module):
     """Sinusoidal positional encoding for non-recurrent neural networks.
     Implementation based on "Attention Is All You Need"

     Args:
         channels (int): embedding size
-        dropout (float): dropout parameter
+        dropout_p (float): dropout rate applied to the output.
+        max_len (int): maximum sequence length.
+        use_scale (bool): whether to use a learnable scaling coefficient.
     """

-    def __init__(self, channels, dropout_p=0.0, max_len=5000):
+    def __init__(self, channels, dropout_p=0.0, max_len=5000, use_scale=False):
         super().__init__()
         if channels % 2 != 0:
             raise ValueError(
                 "Cannot use sin/cos positional encoding with " "odd channels (got channels={:d})".format(channels)
             )
+        self.use_scale = use_scale
+        if use_scale:
+            self.scale = torch.nn.Parameter(torch.ones(1))
         pe = torch.zeros(max_len, channels)
         position = torch.arange(0, max_len).unsqueeze(1)
         div_term = torch.pow(10000, torch.arange(0, channels, 2).float() / channels)
@@ -49,9 +55,15 @@ class PositionalEncoding(nn.Module):
                 pos_enc = self.pe[:, :, : x.size(2)] * mask
             else:
                 pos_enc = self.pe[:, :, : x.size(2)]
-            x = x + pos_enc
+            if self.use_scale:
+                x = x + self.scale * pos_enc
+            else:
+                x = x + pos_enc
         else:
-            x = x + self.pe[:, :, first_idx:last_idx]
+            if self.use_scale:
+                x = x + self.scale * self.pe[:, :, first_idx:last_idx]
+            else:
+                x = x + self.pe[:, :, first_idx:last_idx]
         if hasattr(self, "dropout"):
             x = self.dropout(x)
         return x
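For context, a minimal standalone sketch of what this commit adds: a sinusoidal table built as in the diff, with an optional single learnable coefficient multiplying the encoding before it is added to a [batch, channels, time] input. This is an illustration of the technique only, not the repository's full class (the real module also handles masks, first_idx/last_idx slicing, and dropout); the class name ScaledPositionalEncoding and the simplified forward() are assumptions made for the example.

import torch
from torch import nn


class ScaledPositionalEncoding(nn.Module):
    """Illustrative sketch only: sinusoidal encoding with an optional
    learnable scale, mirroring the use_scale logic added above."""

    def __init__(self, channels, max_len=5000, use_scale=False):
        super().__init__()
        if channels % 2 != 0:
            raise ValueError(f"Cannot use sin/cos positional encoding with odd channels (got {channels})")
        self.use_scale = use_scale
        if use_scale:
            # single learnable coefficient, initialised to 1.0
            self.scale = nn.Parameter(torch.ones(1))
        # build the sinusoidal table: pe[pos, 2i] = sin(pos / 10000^(2i/d)), pe[pos, 2i+1] = cos(...)
        pe = torch.zeros(max_len, channels)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.pow(10000, torch.arange(0, channels, 2).float() / channels)
        pe[:, 0::2] = torch.sin(position / div_term)
        pe[:, 1::2] = torch.cos(position / div_term)
        # store as [1, channels, max_len] so it broadcasts over [B, C, T] inputs
        self.register_buffer("pe", pe.transpose(0, 1).unsqueeze(0))

    def forward(self, x):  # x: [B, C, T]
        pos_enc = self.pe[:, :, : x.size(2)]
        if self.use_scale:
            return x + self.scale * pos_enc
        return x + pos_enc


x = torch.randn(2, 128, 50)  # [batch, channels, time]
enc = ScaledPositionalEncoding(128, use_scale=True)
print(enc(x).shape)  # torch.Size([2, 128, 50])
print("scale" in dict(enc.named_parameters()))  # True: the coefficient is trainable

The design choice is lightweight: with use_scale=True the model can learn how strongly positional information is mixed into the content embeddings (the scale starts at 1.0, so training begins from the standard unscaled behaviour), while the default use_scale=False preserves the previous behaviour.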