mirror of https://github.com/coqui-ai/TTS.git
Fix BCELoss addressing #1192
parent eefd482f51
commit a6f73a18cb
@@ -53,7 +53,7 @@ class TacotronConfig(BaseTTSConfig):
         enable /disable the Stopnet that predicts the end of the decoder sequence. Defaults to True.
         stopnet_pos_weight (float):
             Weight that is applied to over-weight positive instances in the Stopnet loss. Use larger values with
-            datasets with longer sentences. Defaults to 10.
+            datasets with longer sentences. Defaults to 0.2.
         max_decoder_steps (int):
             Max number of steps allowed for the decoder. Defaults to 50.
         encoder_in_features (int):
@@ -161,7 +161,7 @@ class TacotronConfig(BaseTTSConfig):
     prenet_dropout_at_inference: bool = False
     stopnet: bool = True
     separate_stopnet: bool = True
-    stopnet_pos_weight: float = 10.0
+    stopnet_pos_weight: float = 0.2
     max_decoder_steps: int = 500
     encoder_in_features: int = 256
     decoder_in_features: int = 256
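
A minimal sketch of setting the changed field explicitly (not part of the commit; the import path is assumed from the repository layout):

# Assumed import path; TacotronConfig and stopnet_pos_weight come from the diff above.
from TTS.tts.configs.tacotron_config import TacotronConfig

# Values > 1 over-weight the positive ("stop") frames, as the docstring suggests for
# long sentences; the new default of 0.2 under-weights them so that stopping too
# early is penalized more heavily than stopping late.
config = TacotronConfig(stopnet_pos_weight=0.2)
print(config.stopnet_pos_weight)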
@@ -147,9 +147,6 @@ class AttentionEntropyLoss(nn.Module):
         """
         Forces attention to be more decisive by penalizing
         soft attention weights
-
-        TODO: arguments
-        TODO: unit_test
         """
         entropy = torch.distributions.Categorical(probs=align).entropy()
         loss = (entropy / np.log(align.shape[1])).mean()
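
The hunk above only drops stale TODO notes; for context, a standalone sketch of the same normalized-entropy computation on a toy 2-D alignment (the toy tensor is hypothetical):

import numpy as np
import torch

# Toy alignment: batch of 2, attention over 4 encoder steps, each row sums to 1.
align = torch.softmax(torch.randn(2, 4), dim=-1)

# Entropy of each attention distribution, normalized by the maximum possible
# entropy log(num_steps): ~0 for one-hot (decisive) attention, ~1 for uniform
# (maximally soft) attention.
entropy = torch.distributions.Categorical(probs=align).entropy()
loss = (entropy / np.log(align.shape[1])).mean()
print(loss.item())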
@@ -157,9 +154,9 @@ class AttentionEntropyLoss(nn.Module):
 
 
 class BCELossMasked(nn.Module):
-    def __init__(self, pos_weight):
+    def __init__(self, pos_weight: float = None):
         super().__init__()
-        self.pos_weight = pos_weight
+        self.pos_weight = torch.tensor([pos_weight])
 
     def forward(self, x, target, length):
         """
@@ -179,16 +176,15 @@ class BCELossMasked(nn.Module):
         Returns:
             loss: An average loss value in range [0, 1] masked by the length.
         """
-        # mask: (batch, max_len, 1)
         target.requires_grad = False
         if length is not None:
-            mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float()
-            x = x * mask
-            target = target * mask
+            # mask: (batch, max_len, 1)
+            mask = sequence_mask(sequence_length=length, max_len=target.size(1))
             num_items = mask.sum()
+            loss = functional.binary_cross_entropy_with_logits(x.masked_select(mask), target.masked_select(mask), pos_weight=self.pos_weight, reduction="sum")
         else:
-            num_items = torch.numel(x)
             loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum")
+            num_items = torch.numel(x)
         loss = loss / num_items
         return loss
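
The core of the fix is how padded decoder steps are excluded. A small standalone sketch (not part of the commit; the toy tensors are hypothetical) contrasting masked_select with the previous multiply-by-mask approach:

import torch
from torch.nn import functional

# Toy batch: 1 sequence, 6 decoder steps, only the first 4 are valid.
logits = torch.tensor([[-10.0, -10.0, -10.0, 10.0, 0.0, 0.0]])  # stop predicted at step 4
target = torch.tensor([[0.0, 0.0, 0.0, 1.0, 0.0, 0.0]])         # padded steps carry 0
mask = torch.tensor([[True, True, True, True, False, False]])
pos_weight = torch.tensor([0.2])

# New approach: drop padded steps entirely before computing the loss.
masked = functional.binary_cross_entropy_with_logits(
    logits.masked_select(mask), target.masked_select(mask),
    pos_weight=pos_weight, reduction="sum",
) / mask.sum()

# Old approach: zeroing logits and targets still leaves sigmoid(0) = 0.5 on the
# padded steps, so padding leaks a constant penalty into the averaged loss.
zeroed = functional.binary_cross_entropy_with_logits(
    logits * mask.float(), target * mask.float(),
    pos_weight=pos_weight, reduction="sum",
) / mask.sum()

print(masked.item(), zeroed.item())  # masked is ~0, zeroed is inflated by the padding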
@@ -1,8 +1,9 @@
 import unittest
 import torch as T
+from torch.nn import functional
 
 from TTS.tts.utils.helpers import sequence_mask
-from TTS.tts.layers.losses import L1LossMasked, SSIMLoss, MSELossMasked
+from TTS.tts.layers.losses import L1LossMasked, SSIMLoss, MSELossMasked, BCELossMasked
 
 
 class L1LossMaskedTests(unittest.TestCase):
|
@ -200,3 +201,39 @@ class SSIMLossTests(unittest.TestCase):
|
||||||
mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
|
mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
|
||||||
output = layer(dummy_input + mask, dummy_target, dummy_length)
|
output = layer(dummy_input + mask, dummy_target, dummy_length)
|
||||||
assert output.item() == 0, "0 vs {}".format(output.item())
|
assert output.item() == 0, "0 vs {}".format(output.item())
|
||||||
|
|
||||||
|
|
||||||
|
class BCELossTest(unittest.TestCase):
|
||||||
|
def test_in_out(self): # pylint: disable=no-self-use
|
||||||
|
layer = BCELossMasked(pos_weight=5.0)
|
||||||
|
|
||||||
|
length = T.tensor([95])
|
||||||
|
mask = sequence_mask(length, 100)
|
||||||
|
pos_weight = T.tensor([5.0])
|
||||||
|
target = 1. - sequence_mask(length - 1, 100).float() # [0, 0, .... 1, 1] where the first 1 is the last mel frame
|
||||||
|
true_x = target * 200 - 100 # creates logits of [-100, -100, ... 100, 100] corresponding to target
|
||||||
|
zero_x = T.zeros(target.shape) - 100. # simulate logits if it never stops decoding
|
||||||
|
early_x = -200. * sequence_mask(length - 3, 100).float() + 100. # simulate logits on early stopping
|
||||||
|
late_x = -200. * sequence_mask(length + 1, 100).float() + 100. # simulate logits on late stopping
|
||||||
|
|
||||||
|
loss = layer(true_x, target, length)
|
||||||
|
self.assertEqual(loss.item(), 0.0)
|
||||||
|
|
||||||
|
loss = layer(early_x, target, length)
|
||||||
|
self.assertAlmostEqual(loss.item(), 2.1053, places=4)
|
||||||
|
|
||||||
|
loss = layer(late_x, target, length)
|
||||||
|
self.assertAlmostEqual(loss.item(), 5.2632, places=4)
|
||||||
|
|
||||||
|
loss = layer(zero_x, target, length)
|
||||||
|
self.assertAlmostEqual(loss.item(), 5.2632, places=4)
|
||||||
|
|
||||||
|
# pos_weight should be < 1 to penalize early stopping
|
||||||
|
layer = BCELossMasked(pos_weight=0.2)
|
||||||
|
loss = layer(true_x, target, length)
|
||||||
|
self.assertEqual(loss.item(), 0.0)
|
||||||
|
|
||||||
|
# when pos_weight < 1 overweight the early stopping loss
|
||||||
|
loss_early = layer(early_x, target, length)
|
||||||
|
loss_late = layer(late_x, target, length)
|
||||||
|
self.assertGreater(loss_early.item(), loss_late.item())
|
||||||
|
|
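
A back-of-the-envelope check of the expected values in the new test (not part of the commit): with logits of ±100, each wrong step contributes roughly 100 to the summed BCE and each correct step roughly 0, and pos_weight scales only the positive (stop) frames:

num_valid = 95                                 # masked length used in the test
pos_weight = 5.0
late_or_never = pos_weight * 100 / num_valid   # one missed stop frame   -> ~5.2632
early = 2 * 100 / num_valid                    # two spurious stop frames -> ~2.1053
print(round(late_or_never, 4), round(early, 4))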