diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index efd0c2cb..f26cb884 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -163,14 +163,20 @@ class BCELossMasked(nn.Module): """ # mask: (batch, max_len, 1) target.requires_grad = False - mask = sequence_mask(sequence_length=length, - max_len=target.size(1)).float() + if length is not None: + mask = sequence_mask(sequence_length=length, + max_len=target.size(1)).float() + x = x * mask + target = target * mask + num_items = mask.sum() + else: + num_items = torch.numel(x) loss = functional.binary_cross_entropy_with_logits( - x * mask, - target * mask, + x, + target, pos_weight=self.pos_weight, reduction='sum') - loss = loss / mask.sum() + loss = loss / num_items return loss