Pass mask instead of length to model

This commit is contained in:
Eren 2018-08-10 17:43:45 +02:00
parent 2196ce9eba
commit e0bce1d2d1
3 changed files with 9 additions and 16 deletions

View File

@ -105,13 +105,7 @@ class AttentionRNNCell(nn.Module):
'b' (Bahdanau) or 'ls' (Location Sensitive)." 'b' (Bahdanau) or 'ls' (Location Sensitive)."
.format(align_model)) .format(align_model))
def forward(self, def forward(self, memory, context, rnn_state, annots, atten, mask):
memory,
context,
rnn_state,
annots,
atten,
annot_lens=None):
""" """
Shapes: Shapes:
- memory: (batch, 1, dim) or (batch, dim) - memory: (batch, 1, dim) or (batch, dim)
@ -119,7 +113,7 @@ class AttentionRNNCell(nn.Module):
- rnn_state: (batch, out_dim) - rnn_state: (batch, out_dim)
- annots: (batch, max_time, annot_dim) - annots: (batch, max_time, annot_dim)
- atten: (batch, 2, max_time) - atten: (batch, 2, max_time)
- annot_lens: (batch,) - mask: (batch,)
""" """
# Concat input query and previous context context # Concat input query and previous context context
rnn_input = torch.cat((memory, context), -1) rnn_input = torch.cat((memory, context), -1)
@ -133,8 +127,7 @@ class AttentionRNNCell(nn.Module):
alignment = self.alignment_model(annots, rnn_output) alignment = self.alignment_model(annots, rnn_output)
else: else:
alignment = self.alignment_model(annots, rnn_output, atten) alignment = self.alignment_model(annots, rnn_output, atten)
if annot_lens is not None: if mask is not None:
mask = sequence_mask(annot_lens)
mask = mask.view(memory.size(0), -1) mask = mask.view(memory.size(0), -1)
alignment.masked_fill_(1 - mask, -float("inf")) alignment.masked_fill_(1 - mask, -float("inf"))
# Normalize context weight # Normalize context weight

View File

@ -189,7 +189,7 @@ class CBHG(nn.Module):
x = highway(x) x = highway(x)
# (B, T_in, hid_features*2) # (B, T_in, hid_features*2)
# TODO: replace GRU with convolution as in Deep Voice 3 # TODO: replace GRU with convolution as in Deep Voice 3
# self.gru.flatten_parameters() self.gru.flatten_parameters()
outputs, _ = self.gru(x) outputs, _ = self.gru(x)
return outputs return outputs
@ -268,7 +268,7 @@ class Decoder(nn.Module):
self.proj_to_mel = nn.Linear(256, memory_dim * r) self.proj_to_mel = nn.Linear(256, memory_dim * r)
self.stopnet = StopNet(r, memory_dim) self.stopnet = StopNet(r, memory_dim)
def forward(self, inputs, memory=None, input_lens=None): def forward(self, inputs, memory=None, mask=None):
""" """
Decoder forward step. Decoder forward step.
@ -280,7 +280,7 @@ class Decoder(nn.Module):
memory (None): Decoder memory (autoregression. If None (at eval-time), memory (None): Decoder memory (autoregression. If None (at eval-time),
decoder outputs are used as decoder inputs. If None, it uses the last decoder outputs are used as decoder inputs. If None, it uses the last
output as the input. output as the input.
input_lens (None): Time length of each input in batch. mask (None): Attention mask for sequence padding.
Shapes: Shapes:
- inputs: batch x time x encoder_out_dim - inputs: batch x time x encoder_out_dim
@ -332,7 +332,7 @@ class Decoder(nn.Module):
(attention.unsqueeze(1), attention_cum.unsqueeze(1)), dim=1) (attention.unsqueeze(1), attention_cum.unsqueeze(1)), dim=1)
attention_rnn_hidden, current_context_vec, attention = self.attention_rnn( attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
processed_memory, current_context_vec, attention_rnn_hidden, processed_memory, current_context_vec, attention_rnn_hidden,
inputs, attention_cat, input_lens) inputs, attention_cat, mask)
attention_cum += attention attention_cum += attention
# Concat RNN output and attention context vector # Concat RNN output and attention context vector
decoder_input = self.project_to_decoder_in( decoder_input = self.project_to_decoder_in(

View File

@ -25,14 +25,14 @@ class Tacotron(nn.Module):
self.postnet = PostCBHG(mel_dim) self.postnet = PostCBHG(mel_dim)
self.last_linear = nn.Linear(256, linear_dim) self.last_linear = nn.Linear(256, linear_dim)
def forward(self, characters, mel_specs=None, text_lens=None): def forward(self, characters, mel_specs=None, mask=None):
B = characters.size(0) B = characters.size(0)
inputs = self.embedding(characters) inputs = self.embedding(characters)
# batch x time x dim # batch x time x dim
encoder_outputs = self.encoder(inputs) encoder_outputs = self.encoder(inputs)
# batch x time x dim*r # batch x time x dim*r
mel_outputs, alignments, stop_tokens = self.decoder( mel_outputs, alignments, stop_tokens = self.decoder(
encoder_outputs, mel_specs, text_lens) encoder_outputs, mel_specs, mask)
# Reshape # Reshape
# batch x time x dim # batch x time x dim
mel_outputs = mel_outputs.view(B, -1, self.mel_dim) mel_outputs = mel_outputs.view(B, -1, self.mel_dim)