positional encoding masking for SS

erogol 2020-12-29 15:16:06 +01:00
parent 7c95b11fe8
commit ac5c9217d1
3 changed files with 15 additions and 8 deletions


@@ -115,7 +115,7 @@
     "enable_eos_bos_chars": false,  // enable/disable beginning of sentence and end of sentence chars.
     "num_loader_workers": 8,        // number of training data loader processes. Don't set it too big. 4-8 are good values.
     "num_val_loader_workers": 8,    // number of evaluation data loader processes.
-    "batch_group_size": 0,          // Number of batches to shuffle after bucketing.
+    "batch_group_size": 4,          // Number of batches to shuffle after bucketing.
     "min_seq_len": 2,               // DATASET-RELATED: minimum text length to use in training
     "max_seq_len": 300,             // DATASET-RELATED: maximum text length
     "compute_f0": false,            // compute f0 values in data-loader


@@ -33,28 +33,35 @@ class PositionalEncoding(nn.Module):
         self.dropout = nn.Dropout(p=dropout)
         self.dim = dim
 
-    def forward(self, x, step=None):
+    def forward(self, x, mask=None, first_idx=None, last_idx=None):
         """Embed inputs.
 
         Args:
             x (FloatTensor): Sequence of word vectors
                 ``(seq_len, batch_size, self.dim)``
-            step (int or NoneType): If stepwise (``seq_len = 1``), use
-                the encoding for this position.
+            mask (FloatTensor): Sequence mask.
+            first_idx (int or NoneType): starting index for taking a
+                certain part of the embeddings.
+            last_idx (int or NoneType): ending index for taking a
+                certain part of the embeddings.
 
         Shapes:
             x: B x C x T
         """
         x = x * math.sqrt(self.dim)
-        if step is None:
+        if first_idx is None:
             if self.pe.size(2) < x.size(2):
                 raise RuntimeError(
                     f"Sequence is {x.size(2)} but PositionalEncoding is"
                     f" limited to {self.pe.size(2)}. See max_len argument."
                 )
-            x = x + self.pe[:, :, :x.size(2)]
+            if mask is not None:
+                pos_enc = self.pe[:, :, :x.size(2)] * mask
+            else:
+                pos_enc = self.pe[:, :, :x.size(2)]
+            x = x + pos_enc
         else:
-            x = x + self.pe[:, :, step]
+            x = x + self.pe[:, :, first_idx:last_idx]
         if hasattr(self, 'dropout'):
             x = self.dropout(x)
         return x
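
For context, a minimal, self-contained sketch (not the repository's exact module) of what the new mask branch does: a broadcastable sequence mask zeroes the sinusoidal table at padded frames before it is added to the B x C x T input, so padding receives no positional signal. Names such as lengths and div_term are illustrative assumptions.

    import math
    import torch

    B, C, T = 2, 8, 5                          # batch, channels, max frame count (illustrative)
    lengths = torch.tensor([5, 3])             # true length per item; item 2 has 2 padded frames
    x = torch.randn(B, C, T)                   # input in the B x C x T layout used by forward()

    # Sinusoidal table shaped 1 x C x T, analogous to self.pe
    position = torch.arange(T).unsqueeze(1).float()
    div_term = torch.exp(torch.arange(0, C, 2).float() * -(math.log(10000.0) / C))
    pe = torch.zeros(C, T)
    pe[0::2, :] = torch.sin(position * div_term).t()
    pe[1::2, :] = torch.cos(position * div_term).t()
    pe = pe.unsqueeze(0)

    # Sequence mask shaped B x 1 x T: 1 for valid frames, 0 for padding
    mask = (torch.arange(T)[None, :] < lengths[:, None]).unsqueeze(1).float()

    x = x * math.sqrt(C)
    x = x + pe * mask                          # padded frames receive no positional signal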


@@ -81,7 +81,7 @@ class SpeedySpeech(nn.Module):
         # positional encoding
         if hasattr(self, 'pos_encoder'):
-            o_en_ex = self.pos_encoder(o_en_ex)
+            o_en_ex = self.pos_encoder(o_en_ex, y_mask)
         # decoder pass
         o_de = self.decoder(o_en_ex, y_mask)
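
A hedged sketch of how a y_mask of this shape could be built from decoder frame lengths; the sequence_mask helper below is written for illustration and is not necessarily the utility the repository uses.

    import torch

    def sequence_mask(lengths, max_len=None):
        # 1 for valid decoder frames, 0 for padded ones; shaped B x 1 x T_de
        max_len = max_len or int(lengths.max())
        ids = torch.arange(max_len, device=lengths.device)
        return (ids[None, :] < lengths[:, None]).unsqueeze(1).float()

    y_lengths = torch.tensor([120, 95, 87])    # illustrative decoder lengths
    y_mask = sequence_mask(y_lengths)          # broadcastable against B x C x T_de features

With this B x 1 x T_de shape, the same mask can be multiplied into the positional table inside pos_encoder and reused by the decoder pass.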