mirror of https://github.com/coqui-ai/TTS.git
optional loss masking for stoptoken predictor
This commit is contained in:
parent
e49cc3bbcd
commit
fdaed45f58
|
@ -163,14 +163,20 @@ class BCELossMasked(nn.Module):
|
|||
"""
|
||||
# mask: (batch, max_len, 1)
|
||||
target.requires_grad = False
|
||||
mask = sequence_mask(sequence_length=length,
|
||||
max_len=target.size(1)).float()
|
||||
if length is not None:
|
||||
mask = sequence_mask(sequence_length=length,
|
||||
max_len=target.size(1)).float()
|
||||
x = x * mask
|
||||
target = target * mask
|
||||
num_items = mask.sum()
|
||||
else:
|
||||
num_items = torch.numel(x)
|
||||
loss = functional.binary_cross_entropy_with_logits(
|
||||
x * mask,
|
||||
target * mask,
|
||||
x,
|
||||
target,
|
||||
pos_weight=self.pos_weight,
|
||||
reduction='sum')
|
||||
loss = loss / mask.sum()
|
||||
loss = loss / num_items
|
||||
return loss
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue