From 3af989643b16a36c182e614461bf301489028293 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 15 Jan 2020 23:17:55 +0100 Subject: [PATCH] bug fix for losses --- layers/losses.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/layers/losses.py b/layers/losses.py index b8b17c17..90d2ac80 100644 --- a/layers/losses.py +++ b/layers/losses.py @@ -37,6 +37,7 @@ class L1LossMasked(nn.Module): x * mask, target * mask, reduction='none') loss = loss.mul(out_weights.cuda()).sum() else: + mask = mask.expand_as(x) loss = functional.l1_loss( x * mask, target * mask, reduction='sum') loss = loss / mask.sum() @@ -75,6 +76,7 @@ class MSELossMasked(nn.Module): x * mask, target * mask, reduction='none') loss = loss.mul(out_weights.cuda()).sum() else: + mask = mask.expand_as(x) loss = functional.mse_loss( x * mask, target * mask, reduction='sum') loss = loss / mask.sum()