positional encoding masking for SS

erogol 2020-12-29 15:16:06 +01:00
parent 7c95b11fe8
commit ac5c9217d1
3 changed files with 15 additions and 8 deletions

@@ -115,7 +115,7 @@
     "enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
     "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.
     "num_val_loader_workers": 8, // number of evaluation data loader processes.
-    "batch_group_size": 0, // Number of batches to shuffle after bucketing.
+    "batch_group_size": 4, // Number of batches to shuffle after bucketing.
     "min_seq_len": 2, // DATASET-RELATED: minimum text length to use in training
     "max_seq_len": 300, // DATASET-RELATED: maximum text length
     "compute_f0": false, // compute f0 values in data-loader

@@ -33,28 +33,35 @@ class PositionalEncoding(nn.Module):
         self.dropout = nn.Dropout(p=dropout)
         self.dim = dim

-    def forward(self, x, step=None):
+    def forward(self, x, mask=None, first_idx=None, last_idx=None):
         """Embed inputs.
         Args:
             x (FloatTensor): Sequence of word vectors
                 ``(seq_len, batch_size, self.dim)``
-            step (int or NoneType): If stepwise (``seq_len = 1``), use
-                the encoding for this position.
+            mask (FloatTensor): Sequence mask.
+            first_idx (int or NoneType): starting index for taking a
+                certain part of the embeddings.
+            last_idx (int or NoneType): ending index for taking a
+                certain part of the embeddings.
         Shapes:
             x: B x C x T
         """
         x = x * math.sqrt(self.dim)
-        if step is None:
+        if first_idx is None:
             if self.pe.size(2) < x.size(2):
                 raise RuntimeError(
                     f"Sequence is {x.size(2)} but PositionalEncoding is"
                     f" limited to {self.pe.size(2)}. See max_len argument."
                 )
-            x = x + self.pe[:, :, :x.size(2)]
+            if mask is not None:
+                pos_enc = (self.pe[:, :, :x.size(2)] * mask)
+            else:
+                pos_enc = self.pe[:, :, :x.size(2)]
+            x = x + pos_enc
         else:
-            x = x + self.pe[:, :, step]
+            x = x + self.pe[:, :, first_idx:last_idx]
         if hasattr(self, 'dropout'):
             x = self.dropout(x)
         return x
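
The forward() change multiplies the sinusoidal table by the sequence mask, so padded frames receive a zero positional value before it is added to the input (the old step argument is replaced by a first_idx/last_idx slice for taking a sub-range of the embeddings). A standalone sketch of the masked addition on B x C x T inputs; the module below is a simplified stand-in for illustration, not the repo's exact class:

    import math
    import torch
    import torch.nn as nn

    class MaskedPositionalEncoding(nn.Module):
        """Simplified sinusoidal positional encoding with optional masking."""

        def __init__(self, dim, max_len=5000):
            super().__init__()
            pe = torch.zeros(max_len, dim)
            position = torch.arange(0, max_len).unsqueeze(1).float()
            div_term = torch.exp(torch.arange(0, dim, 2).float() * -(math.log(10000.0) / dim))
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
            # stored as (1, dim, max_len) to match B x C x T inputs
            self.register_buffer("pe", pe.transpose(0, 1).unsqueeze(0))
            self.dim = dim

        def forward(self, x, mask=None):
            # x: (B, C, T); mask: (B, 1, T) with 1 for real frames, 0 for padding
            x = x * math.sqrt(self.dim)
            pos_enc = self.pe[:, :, :x.size(2)]
            if mask is not None:
                pos_enc = pos_enc * mask  # zero the encodings at padded positions
            return x + pos_enc

    # usage: positions beyond each sequence's length get no positional signal
    lengths = torch.tensor([5, 3])
    mask = (torch.arange(6).unsqueeze(0) < lengths.unsqueeze(1)).unsqueeze(1).float()  # (2, 1, 6)
    enc = MaskedPositionalEncoding(dim=8)
    out = enc(torch.zeros(2, 8, 6), mask)
    assert torch.all(out[1, :, 3:] == 0)  # padded frames stay zero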

@@ -81,7 +81,7 @@ class SpeedySpeech(nn.Module):
         # positional encoding
         if hasattr(self, 'pos_encoder'):
-            o_en_ex = self.pos_encoder(o_en_ex)
+            o_en_ex = self.pos_encoder(o_en_ex, y_mask)
         # decoder pass
         o_de = self.decoder(o_en_ex, y_mask)
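
y_mask here is the same decoder-side mask already passed to self.decoder, so the positional encodings are zeroed on exactly the frames the decoder treats as padding. A hedged sketch of how such a mask is typically built from the expanded output lengths (the sequence_mask helper and its signature are assumptions for illustration):

    import torch

    def sequence_mask(lengths, max_len=None):
        # True for valid frames, False for padding
        max_len = max_len or int(lengths.max())
        ids = torch.arange(max_len, device=lengths.device)
        return ids.unsqueeze(0) < lengths.unsqueeze(1)

    y_lengths = torch.tensor([120, 87])  # decoder frames per utterance after duration expansion
    y_mask = sequence_mask(y_lengths).unsqueeze(1).float()  # (B, 1, T_de)
    # o_en_ex = pos_encoder(o_en_ex, y_mask)   # masked positional encoding
    # o_de = decoder(o_en_ex, y_mask)          # decoder sees the same mask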