bug fix for stop token prediciton

This commit is contained in:
Eren Golge 2019-03-12 00:19:03 +01:00
parent 527567d7ce
commit 3128378bdf
1 changed files with 1 additions and 2 deletions

View File

@ -476,6 +476,7 @@ class Decoder(nn.Module):
new_memory = outputs[-1] new_memory = outputs[-1]
self._update_memory_queue(new_memory) self._update_memory_queue(new_memory)
output, stop_token, attention = self.decode(inputs, t, None) output, stop_token, attention = self.decode(inputs, t, None)
stop_token = torch.sigmoid(stop_token.data)
outputs += [output] outputs += [output]
attentions += [attention] attentions += [attention]
stop_tokens += [stop_token] stop_tokens += [stop_token]
@ -499,12 +500,10 @@ class StopNet(nn.Module):
super(StopNet, self).__init__() super(StopNet, self).__init__()
self.dropout = nn.Dropout(0.1) self.dropout = nn.Dropout(0.1)
self.linear = nn.Linear(in_features, 1) self.linear = nn.Linear(in_features, 1)
self.sigmoid = nn.Sigmoid()
torch.nn.init.xavier_uniform_( torch.nn.init.xavier_uniform_(
self.linear.weight, gain=torch.nn.init.calculate_gain('linear')) self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))
def forward(self, inputs): def forward(self, inputs):
outputs = self.dropout(inputs) outputs = self.dropout(inputs)
outputs = self.linear(outputs) outputs = self.linear(outputs)
outputs = self.sigmoid(outputs)
return outputs return outputs