Fix BCE loss issue

This commit is contained in:
Eren G??lge 2022-08-15 10:42:13 +02:00
parent d46fbc240c
commit 97cce10250
1 changed files with 2 additions and 1 deletions

View File

@ -1,3 +1,4 @@
from importlib.metadata import requires
import math
import numpy as np
@ -165,7 +166,7 @@ class BCELossMasked(nn.Module):
def __init__(self, pos_weight: float = None):
super().__init__()
self.pos_weight = torch.tensor([pos_weight])
self.pos_weight = nn.Parameter(torch.tensor([pos_weight]), requires_grad=False)
def forward(self, x, target, length):
"""