mirror of https://github.com/coqui-ai/TTS.git
Fix BCELoss adressing #1192
This commit is contained in:
parent
eefd482f51
commit
a6f73a18cb
|
@ -53,7 +53,7 @@ class TacotronConfig(BaseTTSConfig):
|
|||
enable /disable the Stopnet that predicts the end of the decoder sequence. Defaults to True.
|
||||
stopnet_pos_weight (float):
|
||||
Weight that is applied to over-weight positive instances in the Stopnet loss. Use larger values with
|
||||
datasets with longer sentences. Defaults to 10.
|
||||
datasets with longer sentences. Defaults to 0.2.
|
||||
max_decoder_steps (int):
|
||||
Max number of steps allowed for the decoder. Defaults to 50.
|
||||
encoder_in_features (int):
|
||||
|
@ -161,7 +161,7 @@ class TacotronConfig(BaseTTSConfig):
|
|||
prenet_dropout_at_inference: bool = False
|
||||
stopnet: bool = True
|
||||
separate_stopnet: bool = True
|
||||
stopnet_pos_weight: float = 10.0
|
||||
stopnet_pos_weight: float = 0.2
|
||||
max_decoder_steps: int = 500
|
||||
encoder_in_features: int = 256
|
||||
decoder_in_features: int = 256
|
||||
|
|
|
@ -147,9 +147,6 @@ class AttentionEntropyLoss(nn.Module):
|
|||
"""
|
||||
Forces attention to be more decisive by penalizing
|
||||
soft attention weights
|
||||
|
||||
TODO: arguments
|
||||
TODO: unit_test
|
||||
"""
|
||||
entropy = torch.distributions.Categorical(probs=align).entropy()
|
||||
loss = (entropy / np.log(align.shape[1])).mean()
|
||||
|
@ -157,9 +154,9 @@ class AttentionEntropyLoss(nn.Module):
|
|||
|
||||
|
||||
class BCELossMasked(nn.Module):
|
||||
def __init__(self, pos_weight):
|
||||
def __init__(self, pos_weight:float=None):
|
||||
super().__init__()
|
||||
self.pos_weight = pos_weight
|
||||
self.pos_weight = torch.tensor([pos_weight])
|
||||
|
||||
def forward(self, x, target, length):
|
||||
"""
|
||||
|
@ -179,16 +176,15 @@ class BCELossMasked(nn.Module):
|
|||
Returns:
|
||||
loss: An average loss value in range [0, 1] masked by the length.
|
||||
"""
|
||||
# mask: (batch, max_len, 1)
|
||||
target.requires_grad = False
|
||||
if length is not None:
|
||||
mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float()
|
||||
x = x * mask
|
||||
target = target * mask
|
||||
# mask: (batch, max_len, 1)
|
||||
mask = sequence_mask(sequence_length=length, max_len=target.size(1))
|
||||
num_items = mask.sum()
|
||||
loss = functional.binary_cross_entropy_with_logits(x.masked_select(mask), target.masked_select(mask), pos_weight=self.pos_weight, reduction="sum")
|
||||
else:
|
||||
loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum")
|
||||
num_items = torch.numel(x)
|
||||
loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum")
|
||||
loss = loss / num_items
|
||||
return loss
|
||||
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
import unittest
|
||||
import torch as T
|
||||
from torch.nn import functional
|
||||
|
||||
from TTS.tts.utils.helpers import sequence_mask
|
||||
from TTS.tts.layers.losses import L1LossMasked, SSIMLoss, MSELossMasked
|
||||
from TTS.tts.layers.losses import L1LossMasked, SSIMLoss, MSELossMasked, BCELossMasked
|
||||
|
||||
|
||||
class L1LossMaskedTests(unittest.TestCase):
|
||||
|
@ -200,3 +201,39 @@ class SSIMLossTests(unittest.TestCase):
|
|||
mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
|
||||
output = layer(dummy_input + mask, dummy_target, dummy_length)
|
||||
assert output.item() == 0, "0 vs {}".format(output.item())
|
||||
|
||||
|
||||
class BCELossTest(unittest.TestCase):
|
||||
def test_in_out(self): # pylint: disable=no-self-use
|
||||
layer = BCELossMasked(pos_weight=5.0)
|
||||
|
||||
length = T.tensor([95])
|
||||
mask = sequence_mask(length, 100)
|
||||
pos_weight = T.tensor([5.0])
|
||||
target = 1. - sequence_mask(length - 1, 100).float() # [0, 0, .... 1, 1] where the first 1 is the last mel frame
|
||||
true_x = target * 200 - 100 # creates logits of [-100, -100, ... 100, 100] corresponding to target
|
||||
zero_x = T.zeros(target.shape) - 100. # simulate logits if it never stops decoding
|
||||
early_x = -200. * sequence_mask(length - 3, 100).float() + 100. # simulate logits on early stopping
|
||||
late_x = -200. * sequence_mask(length + 1, 100).float() + 100. # simulate logits on late stopping
|
||||
|
||||
loss = layer(true_x, target, length)
|
||||
self.assertEqual(loss.item(), 0.0)
|
||||
|
||||
loss = layer(early_x, target, length)
|
||||
self.assertAlmostEqual(loss.item(), 2.1053, places=4)
|
||||
|
||||
loss = layer(late_x, target, length)
|
||||
self.assertAlmostEqual(loss.item(), 5.2632, places=4)
|
||||
|
||||
loss = layer(zero_x, target, length)
|
||||
self.assertAlmostEqual(loss.item(), 5.2632, places=4)
|
||||
|
||||
# pos_weight should be < 1 to penalize early stopping
|
||||
layer = BCELossMasked(pos_weight=0.2)
|
||||
loss = layer(true_x, target, length)
|
||||
self.assertEqual(loss.item(), 0.0)
|
||||
|
||||
# when pos_weight < 1 overweight the early stopping loss
|
||||
loss_early = layer(early_x, target, length)
|
||||
loss_late = layer(late_x, target, length)
|
||||
self.assertGreater(loss_early.item(), loss_late.item())
|
||||
|
|
Loading…
Reference in New Issue