mirror of https://github.com/coqui-ai/TTS.git
binary measure
This commit is contained in:
parent
a510baa79c
commit
560b094f3a
|
@ -1,11 +1,18 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
def alignment_diagonal_score(alignments):
|
|
||||||
|
def alignment_diagonal_score(alignments, binary=False):
|
||||||
"""
|
"""
|
||||||
Compute how diagonal alignment predictions are. It is useful
|
Compute how diagonal alignment predictions are. It is useful
|
||||||
to measure the alignment consistency of a model
|
to measure the alignment consistency of a model
|
||||||
Args:
|
Args:
|
||||||
alignments (torch.Tensor): batch of alignments.
|
alignments (torch.Tensor): batch of alignments.
|
||||||
|
binary (bool): if True, ignore scores and consider attention
|
||||||
|
as a binary mask.
|
||||||
Shape:
|
Shape:
|
||||||
alignments : batch x decoder_steps x encoder_steps
|
alignments : batch x decoder_steps x encoder_steps
|
||||||
"""
|
"""
|
||||||
return alignments.max(dim=1)[0].mean(dim=1).mean(dim=0).item()
|
if binary:
|
||||||
|
return torch.clamp(alignments.max(dim=1)[0], max=1.0).mean(dim=1).mean(dim=0).item()
|
||||||
|
else:
|
||||||
|
return alignments.max(dim=1)[0].mean(dim=1).mean(dim=0).item()
|
||||||
|
|
Loading…
Reference in New Issue