bug fix for stop token prediciton

This commit is contained in:
Eren Golge 2019-03-12 00:19:03 +01:00
parent 527567d7ce
commit 3128378bdf
1 changed files with 1 additions and 2 deletions

View File

@ -476,6 +476,7 @@ class Decoder(nn.Module):
new_memory = outputs[-1]
self._update_memory_queue(new_memory)
output, stop_token, attention = self.decode(inputs, t, None)
stop_token = torch.sigmoid(stop_token.data)
outputs += [output]
attentions += [attention]
stop_tokens += [stop_token]
@ -499,12 +500,10 @@ class StopNet(nn.Module):
super(StopNet, self).__init__()
self.dropout = nn.Dropout(0.1)
self.linear = nn.Linear(in_features, 1)
self.sigmoid = nn.Sigmoid()
torch.nn.init.xavier_uniform_(
self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))
def forward(self, inputs):
outputs = self.dropout(inputs)
outputs = self.linear(outputs)
outputs = self.sigmoid(outputs)
return outputs