Pass mask instead of length to model

This commit is contained in:
Eren 2018-08-10 17:43:45 +02:00
parent 2196ce9eba
commit e0bce1d2d1
3 changed files with 9 additions and 16 deletions

View File

@ -105,13 +105,7 @@ class AttentionRNNCell(nn.Module):
'b' (Bahdanau) or 'ls' (Location Sensitive)."
.format(align_model))
def forward(self,
memory,
context,
rnn_state,
annots,
atten,
annot_lens=None):
def forward(self, memory, context, rnn_state, annots, atten, mask):
"""
Shapes:
- memory: (batch, 1, dim) or (batch, dim)
@ -119,7 +113,7 @@ class AttentionRNNCell(nn.Module):
- rnn_state: (batch, out_dim)
- annots: (batch, max_time, annot_dim)
- atten: (batch, 2, max_time)
- annot_lens: (batch,)
- mask: (batch,)
"""
# Concat input query and previous context context
rnn_input = torch.cat((memory, context), -1)
@ -133,8 +127,7 @@ class AttentionRNNCell(nn.Module):
alignment = self.alignment_model(annots, rnn_output)
else:
alignment = self.alignment_model(annots, rnn_output, atten)
if annot_lens is not None:
mask = sequence_mask(annot_lens)
if mask is not None:
mask = mask.view(memory.size(0), -1)
alignment.masked_fill_(1 - mask, -float("inf"))
# Normalize context weight

View File

@ -189,7 +189,7 @@ class CBHG(nn.Module):
x = highway(x)
# (B, T_in, hid_features*2)
# TODO: replace GRU with convolution as in Deep Voice 3
# self.gru.flatten_parameters()
self.gru.flatten_parameters()
outputs, _ = self.gru(x)
return outputs
@ -268,7 +268,7 @@ class Decoder(nn.Module):
self.proj_to_mel = nn.Linear(256, memory_dim * r)
self.stopnet = StopNet(r, memory_dim)
def forward(self, inputs, memory=None, input_lens=None):
def forward(self, inputs, memory=None, mask=None):
"""
Decoder forward step.
@ -280,7 +280,7 @@ class Decoder(nn.Module):
memory (None): Decoder memory (autoregression. If None (at eval-time),
decoder outputs are used as decoder inputs. If None, it uses the last
output as the input.
input_lens (None): Time length of each input in batch.
mask (None): Attention mask for sequence padding.
Shapes:
- inputs: batch x time x encoder_out_dim
@ -332,7 +332,7 @@ class Decoder(nn.Module):
(attention.unsqueeze(1), attention_cum.unsqueeze(1)), dim=1)
attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
processed_memory, current_context_vec, attention_rnn_hidden,
inputs, attention_cat, input_lens)
inputs, attention_cat, mask)
attention_cum += attention
# Concat RNN output and attention context vector
decoder_input = self.project_to_decoder_in(

View File

@ -25,14 +25,14 @@ class Tacotron(nn.Module):
self.postnet = PostCBHG(mel_dim)
self.last_linear = nn.Linear(256, linear_dim)
def forward(self, characters, mel_specs=None, text_lens=None):
def forward(self, characters, mel_specs=None, mask=None):
B = characters.size(0)
inputs = self.embedding(characters)
# batch x time x dim
encoder_outputs = self.encoder(inputs)
# batch x time x dim*r
mel_outputs, alignments, stop_tokens = self.decoder(
encoder_outputs, mel_specs, text_lens)
encoder_outputs, mel_specs, mask)
# Reshape
# batch x time x dim
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)