mirror of https://github.com/coqui-ai/TTS.git
Fix BCE loss issue
This commit is contained in:
parent
d46fbc240c
commit
97cce10250
|
@ -1,3 +1,4 @@
|
|||
from importlib.metadata import requires
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
@ -165,7 +166,7 @@ class BCELossMasked(nn.Module):
|
|||
|
||||
def __init__(self, pos_weight: float = None):
|
||||
super().__init__()
|
||||
self.pos_weight = torch.tensor([pos_weight])
|
||||
self.pos_weight = nn.Parameter(torch.tensor([pos_weight]), requires_grad=False)
|
||||
|
||||
def forward(self, x, target, length):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue