From 94c6cb8e0777e0d688ac596ca6c8b23e76331070 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 14 Jan 2020 10:41:41 +0100 Subject: [PATCH] measure update --- utils/measures.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/measures.py b/utils/measures.py index 783d66d1..01d25695 100644 --- a/utils/measures.py +++ b/utils/measures.py @@ -12,7 +12,7 @@ def alignment_diagonal_score(alignments, binary=False): Shape: alignments : batch x decoder_steps x encoder_steps """ + maxs = alignments.max(dim=1)[0] if binary: - return torch.clamp(alignments.max(dim=1)[0], max=1.0).mean(dim=1).mean(dim=0).item() - else: - return alignments.max(dim=1)[0].mean(dim=1).mean(dim=0).item() + maxs[maxs > 0] = 1 + return maxs.mean(dim=1).mean(dim=0).item()