mirror of https://github.com/coqui-ai/TTS.git
positional encoding masking for SS (SpeedySpeech)
commit ac5c9217d1
parent 7c95b11fe8
@@ -115,7 +115,7 @@
     "enable_eos_bos_chars": false,  // enable/disable beginning of sentence and end of sentence chars.
     "num_loader_workers": 8,  // number of training data loader processes. Don't set it too big. 4-8 are good values.
     "num_val_loader_workers": 8,  // number of evaluation data loader processes.
-    "batch_group_size": 0,  // Number of batches to shuffle after bucketing.
+    "batch_group_size": 4,  // Number of batches to shuffle after bucketing.
     "min_seq_len": 2,  // DATASET-RELATED: minimum text length to use in training
     "max_seq_len": 300,  // DATASET-RELATED: maximum text length
     "compute_f0": false,  // compute f0 values in data-loader
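
The "batch_group_size" comment is terse. One common way to implement "shuffle N batches after bucketing" is to sort samples by length and then shuffle only inside windows of batch_group_size * batch_size consecutive items, so batches stay roughly length-homogeneous while their composition still varies between epochs. The sketch below illustrates that idea only; the function name and signature are made up and are not the repository's loader code.

import random

def bucket_and_group_shuffle(samples, batch_size, batch_group_size):
    """Illustrative only: sort by length (bucketing), then shuffle inside
    groups of `batch_group_size` consecutive batches so padding stays small
    while batch composition still changes between epochs."""
    samples = sorted(samples, key=len)
    if batch_group_size > 0:
        group = batch_group_size * batch_size
        for start in range(0, len(samples), group):
            chunk = samples[start:start + group]
            random.shuffle(chunk)
            samples[start:start + group] = chunk
    return samples

# e.g. with batch_size=32 and batch_group_size=4, samples are shuffled
# within windows of 128 length-sorted items.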
@@ -33,28 +33,35 @@ class PositionalEncoding(nn.Module):
         self.dropout = nn.Dropout(p=dropout)
         self.dim = dim
 
-    def forward(self, x, step=None):
+    def forward(self, x, mask=None, first_idx=None, last_idx=None):
         """Embed inputs.
         Args:
             x (FloatTensor): Sequence of word vectors
                 ``(seq_len, batch_size, self.dim)``
-            step (int or NoneType): If stepwise (``seq_len = 1``), use
-                the encoding for this position.
+            mask (FloatTensor): Sequence mask.
+            first_idx (int or NoneType): starting index for taking a
+                certain part of the embeddings.
+            last_idx (int or NoneType): ending index for taking a
+                certain part of the embeddings.
+
+        Shapes:
+            x: B x C x T
         """
 
         x = x * math.sqrt(self.dim)
-        if step is None:
+        if first_idx is None:
             if self.pe.size(2) < x.size(2):
                 raise RuntimeError(
                     f"Sequence is {x.size(2)} but PositionalEncoding is"
                     f" limited to {self.pe.size(2)}. See max_len argument."
                 )
-            x = x + self.pe[:, : ,:x.size(2)]
+            if mask is not None:
+                pos_enc = (self.pe[:, : ,:x.size(2)] * mask)
+            else:
+                pos_enc = self.pe[:, :, :x.size(2)]
+            x = x + pos_enc
         else:
-            x = x + self.pe[:, :, step]
+            x = x + self.pe[:, :, first_idx:last_idx]
         if hasattr(self, 'dropout'):
             x = self.dropout(x)
         return x
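
For context, below is a self-contained sketch of the module after this change. The sinusoidal buffer construction and the class name are assumptions (the diff does not show the constructor); only the forward logic mirrors the commit. The mask is expected as B x 1 x T with ones on real frames, so multiplying it into the table zeroes the positional signal at padded time steps before it is added to x.

import math
import torch
from torch import nn

class PositionalEncodingSketch(nn.Module):
    """Illustrative masked positional encoding; only forward() mirrors the diff."""

    def __init__(self, dim, dropout=0.0, max_len=5000):
        super().__init__()
        # standard sinusoidal table, stored as a 1 x C x T_max buffer (assumed layout)
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, dim, 2).float() * -(math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.transpose(0, 1).unsqueeze(0))
        self.dropout = nn.Dropout(p=dropout)
        self.dim = dim

    def forward(self, x, mask=None, first_idx=None, last_idx=None):
        # x: B x C x T, mask: B x 1 x T (1.0 for real frames, 0.0 for padding)
        x = x * math.sqrt(self.dim)
        if first_idx is None:
            if self.pe.size(2) < x.size(2):
                raise RuntimeError(
                    f"Sequence is {x.size(2)} but PositionalEncoding is"
                    f" limited to {self.pe.size(2)}. See max_len argument."
                )
            pos_enc = self.pe[:, :, :x.size(2)]
            if mask is not None:
                # zero the positional signal on padded frames
                pos_enc = pos_enc * mask
            x = x + pos_enc
        else:
            # take a slice of the table for partial inputs
            x = x + self.pe[:, :, first_idx:last_idx]
        return self.dropout(x)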
@@ -81,7 +81,7 @@ class SpeedySpeech(nn.Module):
 
         # positional encoding
         if hasattr(self, 'pos_encoder'):
-            o_en_ex = self.pos_encoder(o_en_ex)
+            o_en_ex = self.pos_encoder(o_en_ex, y_mask)
 
         # decoder pass
         o_de = self.decoder(o_en_ex, y_mask)
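
Here y_mask is the decoder-side sequence mask that SpeedySpeech already hands to its decoder; reusing it for the positional encoder keeps padded frames free of positional signal as well. A minimal way to build such a mask from per-utterance frame lengths is sketched below (an illustrative helper, not necessarily the repository's own utility).

import torch

def sequence_mask_sketch(lengths, max_len=None):
    # Binary mask of shape B x 1 x T_max: 1.0 for valid frames, 0.0 for padding.
    max_len = int(lengths.max()) if max_len is None else max_len
    ids = torch.arange(max_len, device=lengths.device)
    return (ids[None, :] < lengths[:, None]).unsqueeze(1).float()

# e.g. three utterances of 120, 87 and 95 decoder frames -> mask of shape 3 x 1 x 120
y_mask = sequence_mask_sketch(torch.tensor([120, 87, 95]))
# o_en_ex = self.pos_encoder(o_en_ex, y_mask)  # as in the diff above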