From 560b094f3a488150a3431f99f2840fc923f725df Mon Sep 17 00:00:00 2001 From: root Date: Tue, 14 Jan 2020 02:27:09 +0100 Subject: [PATCH] binary measure --- utils/measures.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/utils/measures.py b/utils/measures.py index a76a2225..783d66d1 100644 --- a/utils/measures.py +++ b/utils/measures.py @@ -1,11 +1,18 @@ +import torch -def alignment_diagonal_score(alignments): + +def alignment_diagonal_score(alignments, binary=False): """ Compute how diagonal alignment predictions are. It is useful to measure the alignment consistency of a model Args: alignments (torch.Tensor): batch of alignments. + binary (bool): if True, ignore scores and consider attention + as a binary mask. Shape: alignments : batch x decoder_steps x encoder_steps """ - return alignments.max(dim=1)[0].mean(dim=1).mean(dim=0).item() + if binary: + return torch.clamp(alignments.max(dim=1)[0], max=1.0).mean(dim=1).mean(dim=0).item() + else: + return alignments.max(dim=1)[0].mean(dim=1).mean(dim=0).item()