convert loss to layer and add test

This commit is contained in:
Eren Golge 2018-03-24 19:22:45 -07:00
parent df4a644326
commit 1dbc51c6b5
3 changed files with 66 additions and 31 deletions

View File

@ -1,6 +1,7 @@
import torch import torch
from torch.nn import functional from torch.nn import functional
from torch.autograd import Variable from torch.autograd import Variable
from torch import nn
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1 # from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
@ -18,34 +19,39 @@ def _sequence_mask(sequence_length, max_len=None):
return seq_range_expand < seq_length_expand return seq_range_expand < seq_length_expand
def L1LossMasked(input, target, length): class L1LossMasked(nn.Module):
"""
Args: def __init__(self):
logits: A Variable containing a FloatTensor of size super(L1LossMasked, self).__init__()
(batch, max_len, num_classes) which contains the
unnormalized probability for each class. def forward(self, input, target, length):
target: A Variable containing a LongTensor of size """
(batch, max_len) which contains the index of the true Args:
class for each corresponding step. logits: A Variable containing a FloatTensor of size
length: A Variable containing a LongTensor of size (batch,) (batch, max_len, num_classes) which contains the
which contains the length of each data in a batch. unnormalized probability for each class.
Returns: target: A Variable containing a LongTensor of size
loss: An average loss value masked by the length. (batch, max_len) which contains the index of the true
""" class for each corresponding step.
input = input.contiguous() length: A Variable containing a LongTensor of size (batch,)
target = target.contiguous() which contains the length of each data in a batch.
Returns:
loss: An average loss value masked by the length.
"""
input = input.contiguous()
target = target.contiguous()
# logits_flat: (batch * max_len, dim) # logits_flat: (batch * max_len, dim)
input = input.view(-1, input.size(-1)) input = input.view(-1, input.size(-1))
# target_flat: (batch * max_len, dim) # target_flat: (batch * max_len, dim)
target_flat = target.view(-1, 1) target_flat = target.view(-1, 1)
# losses_flat: (batch * max_len, dim) # losses_flat: (batch * max_len, dim)
losses_flat = functional.l1_loss(input, target, size_average=False, losses_flat = functional.l1_loss(input, target, size_average=False,
reduce=False) reduce=False)
# losses: (batch, max_len) # losses: (batch, max_len, dim)
losses = losses_flat.view(*target.size()) losses = losses_flat.view(*target.size())
# mask: (batch, max_len) # mask: (batch, max_len, 1)
mask = _sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2) mask = _sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2)
losses = losses * mask.float() losses = losses * mask.float()
loss = losses.sum() / (length.float().sum() * target.shape[2]) loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
return loss return loss

View File

@ -2,6 +2,7 @@ import unittest
import torch as T import torch as T
from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder
from layers.losses import L1LossMasked, _sequence_mask
class PrenetTests(unittest.TestCase): class PrenetTests(unittest.TestCase):
@ -57,4 +58,29 @@ class EncoderTests(unittest.TestCase):
assert output.shape[0] == 4 assert output.shape[0] == 4
assert output.shape[1] == 8 assert output.shape[1] == 8
assert output.shape[2] == 256 # 128 * 2 BiRNN assert output.shape[2] == 256 # 128 * 2 BiRNN
class L1LossMaskedTests(unittest.TestCase):
def test_in_out(self):
layer = L1LossMasked()
dummy_input = T.autograd.Variable(T.ones(4, 8, 128).float())
dummy_target = T.autograd.Variable(T.ones(4, 8, 128).float())
dummy_length = T.autograd.Variable((T.ones(4) * 8).long())
output = layer(dummy_input, dummy_target, dummy_length)
assert output.shape[0] == 1
assert len(output.shape) == 1
assert output.data[0] == 0.0
dummy_input = T.autograd.Variable(T.ones(4, 8, 128).float())
dummy_target = T.autograd.Variable(T.zeros(4, 8, 128).float())
dummy_length = T.autograd.Variable((T.ones(4) * 8).long())
output = layer(dummy_input, dummy_target, dummy_length)
assert output.data[0] == 1.0, "1.0 vs {}".format(output.data[0])
dummy_input = T.autograd.Variable(T.ones(4, 8, 128).float())
dummy_target = T.autograd.Variable(T.zeros(4, 8, 128).float())
dummy_length = T.autograd.Variable((T.arange(5,9)).long())
mask = ((_sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
output = layer(dummy_input + mask, dummy_target, dummy_length)
assert output.data[0] == 1.0, "1.0 vs {}".format(output.data[0])

View File

@ -349,7 +349,10 @@ def main(args):
optimizer = optim.Adam(model.parameters(), lr=c.lr) optimizer = optim.Adam(model.parameters(), lr=c.lr)
criterion = L1LossMasked if use_cuda:
criterion = L1LossMasked().cuda()
else:
criterion = L1LossMasked()
if args.restore_path: if args.restore_path:
checkpoint = torch.load(args.restore_path) checkpoint = torch.load(args.restore_path)