mirror of https://github.com/coqui-ai/TTS.git
measure update
This commit is contained in:
parent
560b094f3a
commit
94c6cb8e07
|
@ -12,7 +12,7 @@ def alignment_diagonal_score(alignments, binary=False):
|
||||||
Shape:
|
Shape:
|
||||||
alignments : batch x decoder_steps x encoder_steps
|
alignments : batch x decoder_steps x encoder_steps
|
||||||
"""
|
"""
|
||||||
|
maxs = alignments.max(dim=1)[0]
|
||||||
if binary:
|
if binary:
|
||||||
return torch.clamp(alignments.max(dim=1)[0], max=1.0).mean(dim=1).mean(dim=0).item()
|
maxs[maxs > 0] = 1
|
||||||
else:
|
return maxs.mean(dim=1).mean(dim=0).item()
|
||||||
return alignments.max(dim=1)[0].mean(dim=1).mean(dim=0).item()
|
|
||||||
|
|
Loading…
Reference in New Issue