measure update

This commit is contained in:
root 2020-01-14 10:41:41 +01:00
parent 560b094f3a
commit 94c6cb8e07
1 changed files with 3 additions and 3 deletions

View File

@ -12,7 +12,7 @@ def alignment_diagonal_score(alignments, binary=False):
Shape: Shape:
alignments : batch x decoder_steps x encoder_steps alignments : batch x decoder_steps x encoder_steps
""" """
maxs = alignments.max(dim=1)[0]
if binary: if binary:
return torch.clamp(alignments.max(dim=1)[0], max=1.0).mean(dim=1).mean(dim=0).item() maxs[maxs > 0] = 1
else: return maxs.mean(dim=1).mean(dim=0).item()
return alignments.max(dim=1)[0].mean(dim=1).mean(dim=0).item()