plot float16 alignments

This commit is contained in:
erogol 2020-12-28 13:53:51 +01:00
parent 13c6665c92
commit e4680e1b99
1 changed files with 2 additions and 0 deletions

View File

@ -17,6 +17,8 @@ def plot_alignment(alignment,
alignment_ = alignment.detach().cpu().numpy().squeeze()
else:
alignment_ = alignment
alignment_ = alignment_.astype(
np.float32) if alignment_.dtype == np.float16 else alignment_
fig, ax = plt.subplots(figsize=fig_size)
im = ax.imshow(alignment_.T,
aspect='auto',