Fix BCELoss addressing #1192

This commit is contained in:
Eren Gölge 2022-07-12 14:11:34 +02:00
parent eefd482f51
commit a6f73a18cb
3 changed files with 46 additions and 13 deletions


@@ -53,7 +53,7 @@ class TacotronConfig(BaseTTSConfig):
             enable /disable the Stopnet that predicts the end of the decoder sequence. Defaults to True.
         stopnet_pos_weight (float):
             Weight that is applied to over-weight positive instances in the Stopnet loss. Use larger values with
-            datasets with longer sentences. Defaults to 10.
+            datasets with longer sentences. Defaults to 0.2.
         max_decoder_steps (int):
             Max number of steps allowed for the decoder. Defaults to 50.
         encoder_in_features (int):
@@ -161,7 +161,7 @@ class TacotronConfig(BaseTTSConfig):
     prenet_dropout_at_inference: bool = False
     stopnet: bool = True
     separate_stopnet: bool = True
-    stopnet_pos_weight: float = 10.0
+    stopnet_pos_weight: float = 0.2
     max_decoder_steps: int = 500
     encoder_in_features: int = 256
     decoder_in_features: int = 256
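For context (my note, not part of the commit): `stopnet_pos_weight` is passed through to `torch.nn.functional.binary_cross_entropy_with_logits`, where `pos_weight` scales only the loss on positive (`target == 1`, i.e. stop) frames. A value below 1 therefore makes a false stop (stopping early) relatively more expensive than a missed stop. A minimal sketch of that behaviour, with invented tensors:

```python
# Sketch only -- illustrates the PyTorch pos_weight semantics the new
# default of 0.2 relies on; none of these tensors come from the commit.
import torch
from torch.nn import functional as F

logits = torch.tensor([-100.0, 100.0])  # confidently wrong on both frames
target = torch.tensor([1.0, 0.0])       # [missed stop, false stop]

for w in (10.0, 0.2):
    loss = F.binary_cross_entropy_with_logits(
        logits, target, pos_weight=torch.tensor([w]), reduction="none"
    )
    # Only the target == 1 term scales with w:
    # w = 10.0 -> [1000., 100.]; w = 0.2 -> [20., 100.]
    print(w, loss)
```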


@@ -147,9 +147,6 @@ class AttentionEntropyLoss(nn.Module):
         """
         Forces attention to be more decisive by penalizing
         soft attention weights
-        TODO: arguments
-        TODO: unit_test
         """
         entropy = torch.distributions.Categorical(probs=align).entropy()
         loss = (entropy / np.log(align.shape[1])).mean()
@@ -157,9 +154,9 @@ class AttentionEntropyLoss(nn.Module):

 class BCELossMasked(nn.Module):
-    def __init__(self, pos_weight):
+    def __init__(self, pos_weight: float = None):
         super().__init__()
-        self.pos_weight = pos_weight
+        self.pos_weight = torch.tensor([pos_weight])

     def forward(self, x, target, length):
         """
@@ -179,16 +176,15 @@ class BCELossMasked(nn.Module):
         Returns:
             loss: An average loss value in range [0, 1] masked by the length.
         """
-        # mask: (batch, max_len, 1)
         target.requires_grad = False
         if length is not None:
-            mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float()
-            x = x * mask
-            target = target * mask
+            # mask: (batch, max_len, 1)
+            mask = sequence_mask(sequence_length=length, max_len=target.size(1))
             num_items = mask.sum()
+            loss = functional.binary_cross_entropy_with_logits(x.masked_select(mask), target.masked_select(mask), pos_weight=self.pos_weight, reduction="sum")
         else:
+            loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum")
             num_items = torch.numel(x)
-        loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum")
         loss = loss / num_items
         return loss
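The substance of the fix, as I read the diff: the old path "masked" padded positions by multiplying with the mask, but a zeroed logit is not a masked one. `sigmoid(0) = 0.5`, so every padded frame still leaked roughly `log 2` into the summed loss before the division by `mask.sum()`. The new path drops padded frames with `masked_select` before the loss is computed. (`__init__` now also wraps `pos_weight` in `torch.tensor([...])`, since the functional API expects a tensor there.) A self-contained sketch of the difference, with made-up tensors:

```python
# Sketch only: contrasts the old mask-by-multiplication with the new
# masked_select approach; tensors are invented for illustration.
import torch
from torch.nn import functional as F

x = torch.full((1, 4), -100.0)  # logits: 2 real frames + 2 padded frames
target = torch.zeros(1, 4)
mask = torch.tensor([[True, True, False, False]])

# Old: multiplying by the mask turns padded logits into 0, and
# sigmoid(0) = 0.5 contributes ~log(2) per padded frame to the sum.
old = F.binary_cross_entropy_with_logits(
    x * mask.float(), target * mask.float(), reduction="sum"
)

# New: masked_select removes padded frames from the loss entirely.
new = F.binary_cross_entropy_with_logits(
    x.masked_select(mask), target.masked_select(mask), reduction="sum"
)
print(old.item(), new.item())  # ~1.3863 (= 2 * log 2) vs ~0.0
```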


@@ -1,8 +1,9 @@
 import unittest

 import torch as T
+from torch.nn import functional

 from TTS.tts.utils.helpers import sequence_mask
-from TTS.tts.layers.losses import L1LossMasked, SSIMLoss, MSELossMasked
+from TTS.tts.layers.losses import L1LossMasked, SSIMLoss, MSELossMasked, BCELossMasked


 class L1LossMaskedTests(unittest.TestCase):
@@ -200,3 +201,39 @@ class SSIMLossTests(unittest.TestCase):
         mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
         output = layer(dummy_input + mask, dummy_target, dummy_length)
         assert output.item() == 0, "0 vs {}".format(output.item())
+
+
+class BCELossTest(unittest.TestCase):
+    def test_in_out(self):  # pylint: disable=no-self-use
+        layer = BCELossMasked(pos_weight=5.0)
+
+        length = T.tensor([95])
+        mask = sequence_mask(length, 100)
+        pos_weight = T.tensor([5.0])
+        target = 1.0 - sequence_mask(length - 1, 100).float()  # [0, 0, ..., 1, 1] where the first 1 marks the last mel frame
+        true_x = target * 200 - 100  # creates logits of [-100, -100, ..., 100, 100] matching the target
+        zero_x = T.zeros(target.shape) - 100.0  # simulate logits if the model never stops decoding
+        early_x = -200.0 * sequence_mask(length - 3, 100).float() + 100.0  # simulate logits on early stopping
+        late_x = -200.0 * sequence_mask(length + 1, 100).float() + 100.0  # simulate logits on late stopping
+
+        loss = layer(true_x, target, length)
+        self.assertEqual(loss.item(), 0.0)
+
+        loss = layer(early_x, target, length)
+        self.assertAlmostEqual(loss.item(), 2.1053, places=4)
+
+        loss = layer(late_x, target, length)
+        self.assertAlmostEqual(loss.item(), 5.2632, places=4)
+
+        loss = layer(zero_x, target, length)
+        self.assertAlmostEqual(loss.item(), 5.2632, places=4)
+
+        # pos_weight should be < 1 to penalize early stopping
+        layer = BCELossMasked(pos_weight=0.2)
+        loss = layer(true_x, target, length)
+        self.assertEqual(loss.item(), 0.0)
+
+        # with pos_weight < 1, early stopping is penalized more than late stopping
+        loss_early = layer(early_x, target, length)
+        loss_late = layer(late_x, target, length)
+        self.assertGreater(loss_early.item(), loss_late.item())
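A quick sanity check on where the expected constants come from (my arithmetic, not part of the test): each confidently wrong frame at `|logit| = 100` costs about 100 nats, and the summed loss is divided by the 95 unmasked frames. `early_x` flips two `target == 0` frames; `late_x` and `zero_x` flip the single `target == 1` frame, which is scaled by `pos_weight`:

```python
# Assumed arithmetic behind the assertAlmostEqual constants above.
print(2 * 100 / 95)    # 2.10526... -> early_x: two false-stop frames, unweighted
print(5.0 * 100 / 95)  # 5.26315... -> late_x / zero_x: one missed stop, scaled by pos_weight = 5
```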