mirror of https://github.com/coqui-ai/TTS.git
loss bug fix
This commit is contained in:
parent
0582346969
commit
58257c4a06
|
@ -27,11 +27,11 @@ class L1LossMasked(nn.Module):
|
|||
def forward(self, input, target, length):
|
||||
"""
|
||||
Args:
|
||||
logits: A Variable containing a FloatTensor of size
|
||||
(batch, max_len, num_classes) which contains the
|
||||
input: A Variable containing a FloatTensor of size
|
||||
(batch, max_len, dim) which contains the
|
||||
unnormalized probability for each class.
|
||||
target: A Variable containing a LongTensor of size
|
||||
(batch, max_len) which contains the index of the true
|
||||
(batch, max_len, dim) which contains the index of the true
|
||||
class for each corresponding step.
|
||||
length: A Variable containing a LongTensor of size (batch,)
|
||||
which contains the length of each data in a batch.
|
||||
|
@ -42,12 +42,12 @@ class L1LossMasked(nn.Module):
|
|||
target = target.contiguous()
|
||||
|
||||
# logits_flat: (batch * max_len, dim)
|
||||
input = input.view(-1, input.size(-1))
|
||||
input = input.view(-1, input.shape[-1])
|
||||
# target_flat: (batch * max_len, dim)
|
||||
target_flat = target.view(-1, 1)
|
||||
target_flat = target.view(-1, target.shape[-1])
|
||||
# losses_flat: (batch * max_len, dim)
|
||||
losses_flat = functional.l1_loss(input, target, size_average=False,
|
||||
reduce=False)
|
||||
reduce=False)
|
||||
# losses: (batch, max_len, dim)
|
||||
losses = losses_flat.view(*target.size())
|
||||
# mask: (batch, max_len, 1)
|
||||
|
|
|
@ -2,7 +2,7 @@ import unittest
|
|||
import torch as T
|
||||
|
||||
from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder
|
||||
from layers.losses import L1LossMasked, _sequence_mask
|
||||
from TTS.layers.losses import L1LossMasked, _sequence_mask
|
||||
|
||||
|
||||
class PrenetTests(unittest.TestCase):
|
||||
|
@ -66,7 +66,7 @@ class L1LossMaskedTests(unittest.TestCase):
|
|||
dummy_target = T.autograd.Variable(T.ones(4, 8, 128).float())
|
||||
dummy_length = T.autograd.Variable((T.ones(4) * 8).long())
|
||||
output = layer(dummy_input, dummy_target, dummy_length)
|
||||
assert output.shape[0] == 1
|
||||
assert output.shape[0] == 0
|
||||
assert len(output.shape) == 1
|
||||
assert output.data[0] == 0.0
|
||||
|
||||
|
|
Loading…
Reference in New Issue