diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index b9a03af1..9933df6b 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -165,7 +165,7 @@ class BCELossMasked(nn.Module): def __init__(self, pos_weight: float = None): super().__init__() - self.pos_weight = torch.tensor([pos_weight]) + self.pos_weight = nn.Parameter(torch.tensor([pos_weight]), requires_grad=False) def forward(self, x, target, length): """