bug fix for losses

This commit is contained in:
root 2020-01-15 23:17:55 +01:00
parent 6fd61e82b0
commit 3af989643b
1 changed files with 2 additions and 0 deletions

View File

@ -37,6 +37,7 @@ class L1LossMasked(nn.Module):
x * mask, target * mask, reduction='none')
loss = loss.mul(out_weights.cuda()).sum()
else:
mask = mask.expand_as(x)
loss = functional.l1_loss(
x * mask, target * mask, reduction='sum')
loss = loss / mask.sum()
@ -75,6 +76,7 @@ class MSELossMasked(nn.Module):
x * mask, target * mask, reduction='none')
loss = loss.mul(out_weights.cuda()).sum()
else:
mask = mask.expand_as(x)
loss = functional.mse_loss(
x * mask, target * mask, reduction='sum')
loss = loss / mask.sum()