mirror of https://github.com/coqui-ai/TTS.git
masked loss
This commit is contained in:
parent
b1beb1f876
commit
e4a0eec77e
|
@ -0,0 +1,48 @@
|
|||
import torch
|
||||
from torch import functional
|
||||
|
||||
|
||||
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
|
||||
def _sequence_mask(sequence_length, max_len=None):
    """Build a boolean mask marking the valid steps of each sequence.

    Args:
        sequence_length: Tensor of shape (batch,) with the number of valid
            steps of each sequence.
        max_len: Width of the mask. Defaults to the largest value found in
            ``sequence_length``.

    Returns:
        Tensor of shape (batch, max_len), on the same device as
        ``sequence_length``, whose entry ``[b, t]`` is true exactly when
        ``t < sequence_length[b]``.
    """
    if max_len is None:
        # ``max()`` yields a 0-dim tensor; ``arange`` needs a plain int.
        max_len = int(sequence_length.max())
    # Create the step indices directly on the lengths' device so no
    # explicit ``.cuda()`` transfer is required.
    seq_range = torch.arange(max_len, device=sequence_length.device)
    # (1, max_len) < (batch, 1) broadcasts to (batch, max_len).
    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)


def L1LossMasked(input, target, length):
    """Compute the L1 loss over the non-padded steps of each sequence.

    Args:
        input: Tensor of predictions of shape (batch, max_len) or
            (batch, max_len, ...) with any number of trailing feature axes.
        target: Tensor of reference values with the same shape as ``input``.
        length: Tensor of shape (batch,) which contains the number of valid
            steps of each sequence in the batch.

    Returns:
        loss: A scalar tensor equal to the sum of the absolute errors of all
            valid steps divided by ``length.sum()``. Trailing feature axes
            are summed, not averaged.

    Raises:
        ValueError: If ``input`` and ``target`` differ in shape.
    """
    if input.shape != target.shape:
        # Fail loudly instead of letting the subtraction broadcast silently.
        raise ValueError(
            "input and target must have the same shape, got "
            "{} and {}".format(tuple(input.shape), tuple(target.shape)))

    # losses: same shape as target, element-wise absolute error.
    losses = torch.abs(input - target)
    # mask: (batch, max_len)
    mask = _sequence_mask(sequence_length=length, max_len=target.size(1))
    # Append singleton axes so the mask broadcasts over feature axes.
    while mask.dim() < losses.dim():
        mask = mask.unsqueeze(-1)
    losses = losses * mask.to(device=losses.device, dtype=losses.dtype)
    loss = losses.sum() / length.float().sum()
    return loss
|
Loading…
Reference in New Issue