tf lstm does not match torch lstm wrt bias vectors. So I avoid bias in LSTM as an easy solution.

This commit is contained in:
erogol 2020-04-28 18:16:37 +02:00
parent d282222553
commit 736f169cc9
1 changed files with 3 additions and 1 deletions

View File

@ -61,6 +61,7 @@ class Encoder(nn.Module):
int(output_input_dim / 2),
num_layers=1,
batch_first=True,
bias=False,
bidirectional=True)
self.rnn_state = None
@ -121,7 +122,8 @@ class Decoder(nn.Module):
bias=False)
self.attention_rnn = nn.LSTMCell(self.prenet_dim + input_dim,
self.query_dim)
self.query_dim,
bias=False)
self.attention = init_attn(attn_type=attn_type,
query_dim=self.query_dim,