Add a masked BCE loss and pad stop_tokens with 0s, not 1s

erogol 2020-02-25 14:17:20 +01:00
parent a5d919cec8
commit a68012aec2
3 changed files with 34 additions and 5 deletions

layers/losses.py

@@ -96,3 +96,32 @@ class AttentionEntropyLoss(nn.Module):
         entropy = torch.distributions.Categorical(probs=align).entropy()
         loss = (entropy / np.log(align.shape[1])).mean()
         return loss
+
+
+class BCELossMasked(nn.Module):
+
+    def __init__(self, pos_weight):
+        super(BCELossMasked, self).__init__()
+        self.pos_weight = pos_weight
+
+    def forward(self, x, target, length):
+        """
+        Args:
+            x: A Variable containing a FloatTensor of size
+                (batch, max_len) which contains the
+                unnormalized stop logit for each decoder step.
+            target: A Variable containing a FloatTensor of size
+                (batch, max_len) which contains the binary stop
+                label (0 or 1) for each corresponding step.
+            length: A Variable containing a LongTensor of size (batch,)
+                which contains the length of each data in a batch.
+        Returns:
+            loss: An average loss value masked by the length.
+        """
+        # mask: (batch, max_len)
+        target.requires_grad = False
+        mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float()
+        loss = functional.binary_cross_entropy_with_logits(
+            x * mask, target * mask, pos_weight=self.pos_weight, reduction='sum')
+        loss = loss / mask.sum()
+        return loss
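
For reference, a minimal standalone sketch of how the new loss is driven, assuming the BCELossMasked class above is in scope; the sequence_mask helper here is a simplified stand-in for the repo's own utility, and the tensor values are illustrative:

    import torch
    from torch.nn import functional

    def sequence_mask(sequence_length, max_len=None):
        # Simplified stand-in: True for valid steps, False for padding.
        if max_len is None:
            max_len = sequence_length.max().item()
        steps = torch.arange(max_len, device=sequence_length.device)
        return steps.unsqueeze(0) < sequence_length.unsqueeze(1)

    # Two stop-token sequences padded to max_len=5, with true lengths 4 and 3.
    logits = torch.randn(2, 5)                     # unnormalized stop predictions
    target = torch.tensor([[0., 0., 0., 1., 0.],   # stop fires at frame 4
                           [0., 0., 1., 0., 0.]])  # stop fires at frame 3
    lengths = torch.tensor([4, 3])

    criterion_st = BCELossMasked(pos_weight=torch.tensor(10.))
    loss = criterion_st(logits, target, lengths)   # scalar, normalized by mask.sum()

Note that masking by multiplication zeroes the padded logits rather than dropping them, so each padded step still contributes a constant log(2) term to the summed loss before normalization; padding the stop targets with 0 (the padding change below) keeps the targets consistent with those zeroed positions.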

train.py

@@ -13,7 +13,7 @@ from torch.utils.data import DataLoader
 from TTS.datasets.TTSDataset import MyDataset
 from distribute import (DistributedSampler, apply_gradient_allreduce,
                         init_distributed, reduce_tensor)
-from TTS.layers.losses import L1LossMasked, MSELossMasked
+from TTS.layers.losses import L1LossMasked, MSELossMasked, BCELossMasked
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import (
     NoamLR, check_update, count_parameters, create_experiment_folder,
@@ -167,7 +167,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         # loss computation
         stop_loss = criterion_st(stop_tokens,
-                                 stop_targets) if c.stopnet else torch.zeros(1)
+                                 stop_targets, mel_lengths) if c.stopnet else torch.zeros(1)
         if c.loss_masking:
             decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
             if c.model in ["Tacotron", "TacotronGST"]:
@@ -365,7 +365,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
             # loss computation
             stop_loss = criterion_st(
-                stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
+                stop_tokens, stop_targets, mel_lengths) if c.stopnet else torch.zeros(1)
             if c.loss_masking:
                 decoder_loss = criterion(decoder_output, mel_input,
                                          mel_lengths)
@@ -571,7 +571,7 @@ def main(args):  # pylint: disable=redefined-outer-name
     else:
         criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
                                                ] else nn.MSELoss()
-    criterion_st = nn.BCEWithLogitsLoss(
+    criterion_st = BCELossMasked(
         pos_weight=torch.tensor(10)) if c.stopnet else None
     if args.restore_path:
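
The pos_weight=torch.tensor(10) argument carries over from the replaced nn.BCEWithLogitsLoss: each sequence contains a single positive stop frame among many negatives, so the positive term is up-weighted. A quick illustration of the effect (the values are made up):

    import torch
    from torch.nn import functional

    logits = torch.tensor([2.0])   # a confident "stop" prediction
    target = torch.tensor([1.0])   # the true stop frame

    plain = functional.binary_cross_entropy_with_logits(logits, target)
    weighted = functional.binary_cross_entropy_with_logits(
        logits, target, pos_weight=torch.tensor(10.0))
    # For a positive target the loss scales linearly: weighted == 10 * plain,
    # so missing a stop frame costs ten times more than a false alarm.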

utils/data.py

@@ -31,7 +31,7 @@ def prepare_tensor(inputs, out_steps):

 def _pad_stop_target(x, length):
-    _pad = 1.
+    _pad = 0.
     assert x.ndim == 1
     return np.pad(
         x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
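
With the pad value flipped from 1. to 0., padded frames of the stop target now read as "keep going" rather than "stop", matching the zeros produced by the mask in BCELossMasked. A quick check of the new behavior, calling the patched function above (values are illustrative):

    import numpy as np

    stop_target = np.array([0., 0., 0., 1.])   # stop fires at the last real frame
    padded = _pad_stop_target(stop_target, 6)
    print(padded)                               # [0. 0. 0. 1. 0. 0.]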