diff --git a/layers/attention.py b/layers/attention.py
index 494d14dc..4f8af178 100644
--- a/layers/attention.py
+++ b/layers/attention.py
@@ -105,13 +105,7 @@ class AttentionRNNCell(nn.Module):
                 'b' (Bahdanau) or 'ls' (Location Sensitive)."
                 .format(align_model))
 
-    def forward(self,
-                memory,
-                context,
-                rnn_state,
-                annots,
-                atten,
-                annot_lens=None):
+    def forward(self, memory, context, rnn_state, annots, atten, mask):
         """
         Shapes:
             - memory: (batch, 1, dim) or (batch, dim)
@@ -119,7 +113,7 @@ class AttentionRNNCell(nn.Module):
             - rnn_state: (batch, out_dim)
             - annots: (batch, max_time, annot_dim)
             - atten: (batch, 2, max_time)
-            - annot_lens: (batch,)
+            - mask: (batch,)
         """
         # Concat input query and previous context context
         rnn_input = torch.cat((memory, context), -1)
@@ -133,8 +127,7 @@ class AttentionRNNCell(nn.Module):
             alignment = self.alignment_model(annots, rnn_output)
         else:
             alignment = self.alignment_model(annots, rnn_output, atten)
-        if annot_lens is not None:
-            mask = sequence_mask(annot_lens)
+        if mask is not None:
             mask = mask.view(memory.size(0), -1)
             alignment.masked_fill_(1 - mask, -float("inf"))
         # Normalize context weight
diff --git a/layers/tacotron.py b/layers/tacotron.py
index 08501ab4..cba6b1ae 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -189,7 +189,7 @@ class CBHG(nn.Module):
             x = highway(x)
         # (B, T_in, hid_features*2)
         # TODO: replace GRU with convolution as in Deep Voice 3
-        # self.gru.flatten_parameters()
+        self.gru.flatten_parameters()
         outputs, _ = self.gru(x)
         return outputs
 
@@ -268,7 +268,7 @@ class Decoder(nn.Module):
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
         self.stopnet = StopNet(r, memory_dim)
 
-    def forward(self, inputs, memory=None, input_lens=None):
+    def forward(self, inputs, memory=None, mask=None):
         """
         Decoder forward step.
 
@@ -280,7 +280,7 @@ class Decoder(nn.Module):
             memory (None): Decoder memory (autoregression. If None (at eval-time),
                 decoder outputs are used as decoder inputs. If None, it uses the last
                 output as the input.
-            input_lens (None): Time length of each input in batch.
+            mask (None): Attention mask for sequence padding.
 
         Shapes:
             - inputs: batch x time x encoder_out_dim
@@ -332,7 +332,7 @@ class Decoder(nn.Module):
                 (attention.unsqueeze(1), attention_cum.unsqueeze(1)), dim=1)
             attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
                 processed_memory, current_context_vec, attention_rnn_hidden,
-                inputs, attention_cat, input_lens)
+                inputs, attention_cat, mask)
             attention_cum += attention
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(
diff --git a/models/tacotron.py b/models/tacotron.py
index b4e4ed27..e25cb467 100644
--- a/models/tacotron.py
+++ b/models/tacotron.py
@@ -25,14 +25,14 @@ class Tacotron(nn.Module):
         self.postnet = PostCBHG(mel_dim)
         self.last_linear = nn.Linear(256, linear_dim)
 
-    def forward(self, characters, mel_specs=None, text_lens=None):
+    def forward(self, characters, mel_specs=None, mask=None):
         B = characters.size(0)
         inputs = self.embedding(characters)
         # batch x time x dim
         encoder_outputs = self.encoder(inputs)
         # batch x time x dim*r
         mel_outputs, alignments, stop_tokens = self.decoder(
-            encoder_outputs, mel_specs, text_lens)
+            encoder_outputs, mel_specs, mask)
         # Reshape
         # batch x time x dim
         mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
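
Note on usage (not part of the diff): with the length arguments replaced by a precomputed mask, the sequence_mask call moves out of AttentionRNNCell and into the caller, which builds the mask once per batch and passes it through Tacotron -> Decoder -> AttentionRNNCell. The sketch below is a minimal, hypothetical illustration of that caller side; the sequence_mask shown here is a stand-in reimplementation (the repo's own helper may differ), and the training-step names (text_lens, characters, mel_specs, model) are assumptions, not code from this PR.

    import torch

    def sequence_mask(seq_lens, max_len=None):
        # (batch,) lengths -> (batch, max_len) uint8 mask: 1 for valid steps, 0 for padding.
        if max_len is None:
            max_len = int(seq_lens.max().item())
        steps = torch.arange(max_len, device=seq_lens.device)
        return (steps.unsqueeze(0) < seq_lens.unsqueeze(1)).to(torch.uint8)

    # Hypothetical training-step usage:
    # text_lens = torch.tensor([7, 5, 9])           # per-utterance character counts
    # mask = sequence_mask(text_lens)               # (batch, max_time), built once per batch
    # outputs = model(characters, mel_specs, mask)  # mask is handed down to the attention cell

Building the mask once per batch avoids recomputing it at every decoder step, which is the practical motivation for threading a mask argument through the call chain instead of the raw lengths.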