Use the last attention value as a threshold to stop decoding. since stoptoken prediction is not precies enough to synthsis at the right time.

This commit is contained in:
Eren Golge 2019-01-13 19:10:03 +01:00
parent ed1f648b83
commit 8969d59902
1 changed files with 1 additions and 1 deletions

View File

@ -434,7 +434,7 @@ class Decoder(nn.Module):
if t >= T_decoder:
break
else:
if t > inputs.shape[1] / 4 and stop_token > 0.6:
if t > inputs.shape[1] / 4 and (stop_token > 0.6 or attention[:, -1].item() > 0.6):
break
elif t > self.max_decoder_steps:
print(" | > Decoder stopped with 'max_decoder_steps")