Pass mask instead of length to model

This commit is contained in:
Eren 2018-08-10 17:43:45 +02:00
parent e614cf1050
commit ba66e34ec7
3 changed files with 9 additions and 16 deletions

View File

@ -105,13 +105,7 @@ class AttentionRNNCell(nn.Module):
'b' (Bahdanau) or 'ls' (Location Sensitive)." 'b' (Bahdanau) or 'ls' (Location Sensitive)."
.format(align_model)) .format(align_model))
def forward(self, def forward(self, memory, context, rnn_state, annots, atten, mask):
memory,
context,
rnn_state,
annots,
atten,
annot_lens=None):
""" """
Shapes: Shapes:
- memory: (batch, 1, dim) or (batch, dim) - memory: (batch, 1, dim) or (batch, dim)
@ -119,7 +113,7 @@ class AttentionRNNCell(nn.Module):
- rnn_state: (batch, out_dim) - rnn_state: (batch, out_dim)
- annots: (batch, max_time, annot_dim) - annots: (batch, max_time, annot_dim)
- atten: (batch, 2, max_time) - atten: (batch, 2, max_time)
- annot_lens: (batch,) - mask: (batch,)
""" """
# Concat input query and previous context context # Concat input query and previous context context
rnn_input = torch.cat((memory, context), -1) rnn_input = torch.cat((memory, context), -1)
@ -133,8 +127,7 @@ class AttentionRNNCell(nn.Module):
alignment = self.alignment_model(annots, rnn_output) alignment = self.alignment_model(annots, rnn_output)
else: else:
alignment = self.alignment_model(annots, rnn_output, atten) alignment = self.alignment_model(annots, rnn_output, atten)
if annot_lens is not None: if mask is not None:
mask = sequence_mask(annot_lens)
mask = mask.view(memory.size(0), -1) mask = mask.view(memory.size(0), -1)
alignment.masked_fill_(1 - mask, -float("inf")) alignment.masked_fill_(1 - mask, -float("inf"))
# Normalize context weight # Normalize context weight

View File

@ -189,7 +189,7 @@ class CBHG(nn.Module):
x = highway(x) x = highway(x)
# (B, T_in, hid_features*2) # (B, T_in, hid_features*2)
# TODO: replace GRU with convolution as in Deep Voice 3 # TODO: replace GRU with convolution as in Deep Voice 3
# self.gru.flatten_parameters() self.gru.flatten_parameters()
outputs, _ = self.gru(x) outputs, _ = self.gru(x)
return outputs return outputs
@ -268,7 +268,7 @@ class Decoder(nn.Module):
self.proj_to_mel = nn.Linear(256, memory_dim * r) self.proj_to_mel = nn.Linear(256, memory_dim * r)
self.stopnet = StopNet(r, memory_dim) self.stopnet = StopNet(r, memory_dim)
def forward(self, inputs, memory=None, input_lens=None): def forward(self, inputs, memory=None, mask=None):
""" """
Decoder forward step. Decoder forward step.
@ -280,7 +280,7 @@ class Decoder(nn.Module):
memory (None): Decoder memory (autoregression. If None (at eval-time), memory (None): Decoder memory (autoregression. If None (at eval-time),
decoder outputs are used as decoder inputs. If None, it uses the last decoder outputs are used as decoder inputs. If None, it uses the last
output as the input. output as the input.
input_lens (None): Time length of each input in batch. mask (None): Attention mask for sequence padding.
Shapes: Shapes:
- inputs: batch x time x encoder_out_dim - inputs: batch x time x encoder_out_dim
@ -332,7 +332,7 @@ class Decoder(nn.Module):
(attention.unsqueeze(1), attention_cum.unsqueeze(1)), dim=1) (attention.unsqueeze(1), attention_cum.unsqueeze(1)), dim=1)
attention_rnn_hidden, current_context_vec, attention = self.attention_rnn( attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
processed_memory, current_context_vec, attention_rnn_hidden, processed_memory, current_context_vec, attention_rnn_hidden,
inputs, attention_cat, input_lens) inputs, attention_cat, mask)
attention_cum += attention attention_cum += attention
# Concat RNN output and attention context vector # Concat RNN output and attention context vector
decoder_input = self.project_to_decoder_in( decoder_input = self.project_to_decoder_in(

View File

@ -25,14 +25,14 @@ class Tacotron(nn.Module):
self.postnet = PostCBHG(mel_dim) self.postnet = PostCBHG(mel_dim)
self.last_linear = nn.Linear(256, linear_dim) self.last_linear = nn.Linear(256, linear_dim)
def forward(self, characters, mel_specs=None, text_lens=None): def forward(self, characters, mel_specs=None, mask=None):
B = characters.size(0) B = characters.size(0)
inputs = self.embedding(characters) inputs = self.embedding(characters)
# batch x time x dim # batch x time x dim
encoder_outputs = self.encoder(inputs) encoder_outputs = self.encoder(inputs)
# batch x time x dim*r # batch x time x dim*r
mel_outputs, alignments, stop_tokens = self.decoder( mel_outputs, alignments, stop_tokens = self.decoder(
encoder_outputs, mel_specs, text_lens) encoder_outputs, mel_specs, mask)
# Reshape # Reshape
# batch x time x dim # batch x time x dim
mel_outputs = mel_outputs.view(B, -1, self.mel_dim) mel_outputs = mel_outputs.view(B, -1, self.mel_dim)