mirror of https://github.com/coqui-ai/TTS.git
Add optional loss masking for the stop-token predictor
This commit is contained in:
parent
e49cc3bbcd
commit
fdaed45f58
|
@ -163,14 +163,20 @@ class BCELossMasked(nn.Module):
|
||||||
"""
|
"""
|
||||||
# mask: (batch, max_len, 1)
|
# mask: (batch, max_len, 1)
|
||||||
target.requires_grad = False
|
target.requires_grad = False
|
||||||
mask = sequence_mask(sequence_length=length,
|
if length is not None:
|
||||||
max_len=target.size(1)).float()
|
mask = sequence_mask(sequence_length=length,
|
||||||
|
max_len=target.size(1)).float()
|
||||||
|
x = x * mask
|
||||||
|
target = target * mask
|
||||||
|
num_items = mask.sum()
|
||||||
|
else:
|
||||||
|
num_items = torch.numel(x)
|
||||||
loss = functional.binary_cross_entropy_with_logits(
|
loss = functional.binary_cross_entropy_with_logits(
|
||||||
x * mask,
|
x,
|
||||||
target * mask,
|
target,
|
||||||
pos_weight=self.pos_weight,
|
pos_weight=self.pos_weight,
|
||||||
reduction='sum')
|
reduction='sum')
|
||||||
loss = loss / mask.sum()
|
loss = loss / num_items
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue