measure update

This commit is contained in:
root 2020-01-14 10:41:41 +01:00
parent 560b094f3a
commit 94c6cb8e07
1 changed files with 3 additions and 3 deletions

View File

@ -12,7 +12,7 @@ def alignment_diagonal_score(alignments, binary=False):
Shape:
alignments : batch x decoder_steps x encoder_steps
"""
maxs = alignments.max(dim=1)[0]
if binary:
return torch.clamp(alignments.max(dim=1)[0], max=1.0).mean(dim=1).mean(dim=0).item()
else:
return alignments.max(dim=1)[0].mean(dim=1).mean(dim=0).item()
maxs[maxs > 0] = 1
return maxs.mean(dim=1).mean(dim=0).item()