loss bug fix

This commit is contained in:
Eren Golge 2018-03-28 18:20:56 -07:00
parent 0582346969
commit 58257c4a06
2 changed files with 8 additions and 8 deletions

View File

@ -27,11 +27,11 @@ class L1LossMasked(nn.Module):
def forward(self, input, target, length): def forward(self, input, target, length):
""" """
Args: Args:
logits: A Variable containing a FloatTensor of size input: A Variable containing a FloatTensor of size
(batch, max_len, num_classes) which contains the (batch, max_len, dim) which contains the
unnormalized probability for each class. unnormalized probability for each class.
target: A Variable containing a LongTensor of size target: A Variable containing a LongTensor of size
(batch, max_len) which contains the index of the true (batch, max_len, dim) which contains the index of the true
class for each corresponding step. class for each corresponding step.
length: A Variable containing a LongTensor of size (batch,) length: A Variable containing a LongTensor of size (batch,)
which contains the length of each data in a batch. which contains the length of each data in a batch.
@ -42,12 +42,12 @@ class L1LossMasked(nn.Module):
target = target.contiguous() target = target.contiguous()
# logits_flat: (batch * max_len, dim) # logits_flat: (batch * max_len, dim)
input = input.view(-1, input.size(-1)) input = input.view(-1, input.shape[-1])
# target_flat: (batch * max_len, dim) # target_flat: (batch * max_len, dim)
target_flat = target.view(-1, 1) target_flat = target.view(-1, target.shape[-1])
# losses_flat: (batch * max_len, dim) # losses_flat: (batch * max_len, dim)
losses_flat = functional.l1_loss(input, target, size_average=False, losses_flat = functional.l1_loss(input, target, size_average=False,
reduce=False) reduce=False)
# losses: (batch, max_len, dim) # losses: (batch, max_len, dim)
losses = losses_flat.view(*target.size()) losses = losses_flat.view(*target.size())
# mask: (batch, max_len, 1) # mask: (batch, max_len, 1)

View File

@ -2,7 +2,7 @@ import unittest
import torch as T import torch as T
from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder
from layers.losses import L1LossMasked, _sequence_mask from TTS.layers.losses import L1LossMasked, _sequence_mask
class PrenetTests(unittest.TestCase): class PrenetTests(unittest.TestCase):
@ -66,7 +66,7 @@ class L1LossMaskedTests(unittest.TestCase):
dummy_target = T.autograd.Variable(T.ones(4, 8, 128).float()) dummy_target = T.autograd.Variable(T.ones(4, 8, 128).float())
dummy_length = T.autograd.Variable((T.ones(4) * 8).long()) dummy_length = T.autograd.Variable((T.ones(4) * 8).long())
output = layer(dummy_input, dummy_target, dummy_length) output = layer(dummy_input, dummy_target, dummy_length)
assert output.shape[0] == 1 assert output.shape[0] == 0
assert len(output.shape) == 1 assert len(output.shape) == 1
assert output.data[0] == 0.0 assert output.data[0] == 0.0