Use stop token prediction to stop decoding.

Eren Golge 2018-05-03 05:56:06 -07:00
parent c54600e0d8
commit 192ff3b89e
1 changed file with 2 additions and 2 deletions


@@ -315,7 +315,7 @@ class Decoder(nn.Module):
                 if t >= T_decoder:
                     break
             else:
-                if t > 1 and is_end_of_frames(output.view(self.r, -1), alignment, self.eps):
+                if t > 1 and stop_token.sum().item() > 0.7:
                     break
                 elif t > self.max_decoder_steps:
                     print(" !! Decoder stopped with 'max_decoder_steps'. \
@@ -326,7 +326,7 @@ class Decoder(nn.Module):
         alignments = torch.stack(alignments).transpose(0, 1)
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
         stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
-        return outputs, alignments, stop_tokens
+        return outputs, alignments, stop_tokens.squeeze()


 def is_end_of_frames(output, alignment, eps=0.01): # 0.2
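
The second hunk squeezes the stacked stop tokens before returning them, dropping the trailing singleton dimension so the predictions line up with a per-frame 0/1 stop target. The snippet below sketches how that squeezed tensor could feed a binary cross-entropy stop loss; the shapes, `stop_target`, and the use of `nn.BCELoss` are assumptions for illustration, not taken from this commit.

```python
import torch
import torch.nn as nn

# Assumed shapes for illustration: the decoder returns stop_tokens of shape
# (batch, T, 1); .squeeze() drops the trailing singleton so the predictions
# match a (batch, T) tensor of 0/1 stop labels.
batch_size, T = 4, 120
stop_tokens = torch.sigmoid(torch.randn(batch_size, T, 1))  # predicted stop probabilities
stop_target = torch.zeros(batch_size, T)                    # hypothetical labels: 1 only at the final frame
stop_target[:, -1] = 1.0

criterion = nn.BCELoss()
loss = criterion(stop_tokens.squeeze(), stop_target)
print(loss.item())
```

One caveat with plain `.squeeze()`: it also removes the batch dimension when the batch size is 1, so `.squeeze(-1)` would be the more defensive choice if single-item batches can occur.
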