From 29248536c90db097716e87a035b1c2dbfcbc5563 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Fri, 3 Sep 2021 13:28:46 +0000
Subject: [PATCH] Update `PositionalEncoding`

---
 TTS/tts/layers/generic/pos_encoding.py | 20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/TTS/tts/layers/generic/pos_encoding.py b/TTS/tts/layers/generic/pos_encoding.py
index 46a0b516..913add0d 100644
--- a/TTS/tts/layers/generic/pos_encoding.py
+++ b/TTS/tts/layers/generic/pos_encoding.py
@@ -7,17 +7,23 @@ from torch import nn
 class PositionalEncoding(nn.Module):
     """Sinusoidal positional encoding for non-recurrent neural networks.
     Implementation based on "Attention Is All You Need"
+
     Args:
        channels (int): embedding size
-       dropout (float): dropout parameter
+       dropout_p (float): dropout rate applied to the output.
+       max_len (int): maximum sequence length.
+       use_scale (bool): whether to use a learnable scaling coefficient.
     """
 
-    def __init__(self, channels, dropout_p=0.0, max_len=5000):
+    def __init__(self, channels, dropout_p=0.0, max_len=5000, use_scale=False):
         super().__init__()
         if channels % 2 != 0:
             raise ValueError(
                 "Cannot use sin/cos positional encoding with " "odd channels (got channels={:d})".format(channels)
             )
+        self.use_scale = use_scale
+        if use_scale:
+            self.scale = torch.nn.Parameter(torch.ones(1))
         pe = torch.zeros(max_len, channels)
         position = torch.arange(0, max_len).unsqueeze(1)
         div_term = torch.pow(10000, torch.arange(0, channels, 2).float() / channels)
@@ -49,9 +55,15 @@ class PositionalEncoding(nn.Module):
                 pos_enc = self.pe[:, :, : x.size(2)] * mask
             else:
                 pos_enc = self.pe[:, :, : x.size(2)]
-            x = x + pos_enc
+            if self.use_scale:
+                x = x + self.scale * pos_enc
+            else:
+                x = x + pos_enc
         else:
-            x = x + self.pe[:, :, first_idx:last_idx]
+            if self.use_scale:
+                x = x + self.scale * self.pe[:, :, first_idx:last_idx]
+            else:
+                x = x + self.pe[:, :, first_idx:last_idx]
         if hasattr(self, "dropout"):
             x = self.dropout(x)
         return x
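
Below is a minimal usage sketch of the patched module, not part of the patch itself. The constructor arguments (`channels`, `dropout_p`, `max_len`, `use_scale`) and the `first_idx`/`last_idx` names come from the diff hunks above; the [B, C, T] input layout and the exact `forward()` keyword handling are assumptions inferred from the context lines, not confirmed by the patch.

    import torch

    from TTS.tts.layers.generic.pos_encoding import PositionalEncoding

    # Toy input in [B, C, T] layout (assumed from the `x.size(2)` / `self.pe[:, :, ...]` indexing).
    x = torch.randn(2, 128, 50)

    # With use_scale=True, a single learnable coefficient (initialized to 1.0) scales the
    # positional encoding before it is added to the input.
    pos_encoder = PositionalEncoding(channels=128, dropout_p=0.1, max_len=5000, use_scale=True)

    y = pos_encoder(x)   # full-sequence path: adds scale * pe[:, :, : x.size(2)]
    print(y.shape)       # torch.Size([2, 128, 50])

    # Windowed path (keyword names taken from the diff): adds scale * pe[:, :, first_idx:last_idx].
    y_win = pos_encoder(x[:, :, :10], first_idx=0, last_idx=10)
    print(y_win.shape)   # torch.Size([2, 128, 10])

When `use_scale=False` (the default), the module behaves as before the patch, so existing checkpoints and configs are unaffected.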