mirror of https://github.com/coqui-ai/TTS.git
Update `PositionalEncoding`
parent 076d0cb258
commit 29248536c9
@@ -7,17 +7,23 @@ from torch import nn
 class PositionalEncoding(nn.Module):
     """Sinusoidal positional encoding for non-recurrent neural networks.
     Implementation based on "Attention Is All You Need"

     Args:
         channels (int): embedding size
-        dropout (float): dropout parameter
+        dropout_p (float): dropout rate applied to the output.
+        max_len (int): maximum sequence length.
+        use_scale (bool): whether to use a learnable scaling coefficient.
     """

-    def __init__(self, channels, dropout_p=0.0, max_len=5000):
+    def __init__(self, channels, dropout_p=0.0, max_len=5000, use_scale=False):
         super().__init__()
         if channels % 2 != 0:
             raise ValueError(
                 "Cannot use sin/cos positional encoding with " "odd channels (got channels={:d})".format(channels)
             )
+        self.use_scale = use_scale
+        if use_scale:
+            self.scale = torch.nn.Parameter(torch.ones(1))
         pe = torch.zeros(max_len, channels)
         position = torch.arange(0, max_len).unsqueeze(1)
         div_term = torch.pow(10000, torch.arange(0, channels, 2).float() / channels)
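For orientation, the hunk above cuts off right after `div_term`. Below is a brief sketch of how the standard sinusoidal table from "Attention Is All You Need" is typically completed from that point; this is the textbook construction, not a verbatim copy of the rest of the file, and the final [1, channels, max_len] buffer layout is only inferred from the `self.pe[:, :, : x.size(2)]` indexing in the next hunk.

import torch

channels, max_len = 128, 5000  # illustrative values only
pe = torch.zeros(max_len, channels)
position = torch.arange(0, max_len).unsqueeze(1).float()
div_term = torch.pow(10000, torch.arange(0, channels, 2).float() / channels)
# Even channels get sine, odd channels get cosine, i.e.
# pe[pos, 2i] = sin(pos / 10000^(2i/channels)), pe[pos, 2i+1] = cos(pos / 10000^(2i/channels)).
pe[:, 0::2] = torch.sin(position / div_term)
pe[:, 1::2] = torch.cos(position / div_term)
# Reshape so the time axis comes last, matching the pe[:, :, :T] slicing used
# in forward(); the exact buffer registration is assumed, not shown in this diff.
pe = pe.transpose(0, 1).unsqueeze(0)  # [1, channels, max_len]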
@@ -49,7 +55,13 @@ class PositionalEncoding(nn.Module):
                 pos_enc = self.pe[:, :, : x.size(2)] * mask
             else:
                 pos_enc = self.pe[:, :, : x.size(2)]
-            x = x + pos_enc
+            if self.use_scale:
+                x = x + self.scale * pos_enc
+            else:
+                x = x + pos_enc
         else:
-            x = x + self.pe[:, :, first_idx:last_idx]
+            if self.use_scale:
+                x = x + self.scale * self.pe[:, :, first_idx:last_idx]
+            else:
+                x = x + self.pe[:, :, first_idx:last_idx]
         if hasattr(self, "dropout"):
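To make the effect of the new `use_scale` branches concrete, here is a minimal sketch of the forward-pass logic with dummy tensors. The [batch, channels, time] input layout and the stand-in names (`pe`, `scale`, `first_idx`, `last_idx`) mirror the identifiers in the hunk, but the surrounding module code is assumed rather than quoted.

import torch

batch, channels, time_len, max_len = 2, 8, 6, 20   # illustrative sizes only
x = torch.randn(batch, channels, time_len)
pe = torch.randn(1, channels, max_len)              # stand-in for self.pe
scale = torch.nn.Parameter(torch.ones(1))           # stand-in for self.scale
use_scale = True
first_idx, last_idx = None, None

if first_idx is None:
    pos_enc = pe[:, :, : x.size(2)]                 # mask handling omitted for brevity
    if use_scale:
        x = x + scale * pos_enc                     # scaled path added by this commit
    else:
        x = x + pos_enc                             # original behaviour
else:
    if use_scale:
        x = x + scale * pe[:, :, first_idx:last_idx]
    else:
        x = x + pe[:, :, first_idx:last_idx]

print(x.shape)  # torch.Size([2, 8, 6])

Because the learnable scale is initialized with torch.ones(1), the module starts out identical to the unscaled version and only departs from it if training finds a different weighting of the positional table useful.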