mirror of https://github.com/coqui-ai/TTS.git
convert loss to layer and add test
This commit is contained in:
parent
c61a4419e5
commit
9f5dfcc59a
|
@ -1,6 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import functional
|
from torch.nn import functional
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
|
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
|
||||||
|
@ -18,34 +19,39 @@ def _sequence_mask(sequence_length, max_len=None):
|
||||||
return seq_range_expand < seq_length_expand
|
return seq_range_expand < seq_length_expand
|
||||||
|
|
||||||
|
|
||||||
def L1LossMasked(input, target, length):
|
class L1LossMasked(nn.Module):
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
logits: A Variable containing a FloatTensor of size
|
|
||||||
(batch, max_len, num_classes) which contains the
|
|
||||||
unnormalized probability for each class.
|
|
||||||
target: A Variable containing a LongTensor of size
|
|
||||||
(batch, max_len) which contains the index of the true
|
|
||||||
class for each corresponding step.
|
|
||||||
length: A Variable containing a LongTensor of size (batch,)
|
|
||||||
which contains the length of each data in a batch.
|
|
||||||
Returns:
|
|
||||||
loss: An average loss value masked by the length.
|
|
||||||
"""
|
|
||||||
input = input.contiguous()
|
|
||||||
target = target.contiguous()
|
|
||||||
|
|
||||||
# logits_flat: (batch * max_len, dim)
|
def __init__(self):
|
||||||
input = input.view(-1, input.size(-1))
|
super(L1LossMasked, self).__init__()
|
||||||
# target_flat: (batch * max_len, dim)
|
|
||||||
target_flat = target.view(-1, 1)
|
def forward(self, input, target, length):
|
||||||
# losses_flat: (batch * max_len, dim)
|
"""
|
||||||
losses_flat = functional.l1_loss(input, target, size_average=False,
|
Args:
|
||||||
reduce=False)
|
logits: A Variable containing a FloatTensor of size
|
||||||
# losses: (batch, max_len)
|
(batch, max_len, num_classes) which contains the
|
||||||
losses = losses_flat.view(*target.size())
|
unnormalized probability for each class.
|
||||||
# mask: (batch, max_len)
|
target: A Variable containing a LongTensor of size
|
||||||
mask = _sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2)
|
(batch, max_len) which contains the index of the true
|
||||||
losses = losses * mask.float()
|
class for each corresponding step.
|
||||||
loss = losses.sum() / (length.float().sum() * target.shape[2])
|
length: A Variable containing a LongTensor of size (batch,)
|
||||||
return loss
|
which contains the length of each data in a batch.
|
||||||
|
Returns:
|
||||||
|
loss: An average loss value masked by the length.
|
||||||
|
"""
|
||||||
|
input = input.contiguous()
|
||||||
|
target = target.contiguous()
|
||||||
|
|
||||||
|
# logits_flat: (batch * max_len, dim)
|
||||||
|
input = input.view(-1, input.size(-1))
|
||||||
|
# target_flat: (batch * max_len, dim)
|
||||||
|
target_flat = target.view(-1, 1)
|
||||||
|
# losses_flat: (batch * max_len, dim)
|
||||||
|
losses_flat = functional.l1_loss(input, target, size_average=False,
|
||||||
|
reduce=False)
|
||||||
|
# losses: (batch, max_len, dim)
|
||||||
|
losses = losses_flat.view(*target.size())
|
||||||
|
# mask: (batch, max_len, 1)
|
||||||
|
mask = _sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2)
|
||||||
|
losses = losses * mask.float()
|
||||||
|
loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
|
||||||
|
return loss
|
|
@ -2,6 +2,7 @@ import unittest
|
||||||
import torch as T
|
import torch as T
|
||||||
|
|
||||||
from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder
|
from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder
|
||||||
|
from layers.losses import L1LossMasked, _sequence_mask
|
||||||
|
|
||||||
|
|
||||||
class PrenetTests(unittest.TestCase):
|
class PrenetTests(unittest.TestCase):
|
||||||
|
@ -58,3 +59,28 @@ class EncoderTests(unittest.TestCase):
|
||||||
assert output.shape[1] == 8
|
assert output.shape[1] == 8
|
||||||
assert output.shape[2] == 256 # 128 * 2 BiRNN
|
assert output.shape[2] == 256 # 128 * 2 BiRNN
|
||||||
|
|
||||||
|
|
||||||
|
class L1LossMaskedTests(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_in_out(self):
|
||||||
|
layer = L1LossMasked()
|
||||||
|
dummy_input = T.autograd.Variable(T.ones(4, 8, 128).float())
|
||||||
|
dummy_target = T.autograd.Variable(T.ones(4, 8, 128).float())
|
||||||
|
dummy_length = T.autograd.Variable((T.ones(4) * 8).long())
|
||||||
|
output = layer(dummy_input, dummy_target, dummy_length)
|
||||||
|
assert output.shape[0] == 1
|
||||||
|
assert len(output.shape) == 1
|
||||||
|
assert output.data[0] == 0.0
|
||||||
|
|
||||||
|
dummy_input = T.autograd.Variable(T.ones(4, 8, 128).float())
|
||||||
|
dummy_target = T.autograd.Variable(T.zeros(4, 8, 128).float())
|
||||||
|
dummy_length = T.autograd.Variable((T.ones(4) * 8).long())
|
||||||
|
output = layer(dummy_input, dummy_target, dummy_length)
|
||||||
|
assert output.data[0] == 1.0, "1.0 vs {}".format(output.data[0])
|
||||||
|
|
||||||
|
dummy_input = T.autograd.Variable(T.ones(4, 8, 128).float())
|
||||||
|
dummy_target = T.autograd.Variable(T.zeros(4, 8, 128).float())
|
||||||
|
dummy_length = T.autograd.Variable((T.arange(5,9)).long())
|
||||||
|
mask = ((_sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
|
||||||
|
output = layer(dummy_input + mask, dummy_target, dummy_length)
|
||||||
|
assert output.data[0] == 1.0, "1.0 vs {}".format(output.data[0])
|
||||||
|
|
5
train.py
5
train.py
|
@ -349,7 +349,10 @@ def main(args):
|
||||||
|
|
||||||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||||
|
|
||||||
criterion = L1LossMasked
|
if use_cuda:
|
||||||
|
criterion = L1LossMasked().cuda()
|
||||||
|
else:
|
||||||
|
criterion = L1LossMasked()
|
||||||
|
|
||||||
if args.restore_path:
|
if args.restore_path:
|
||||||
checkpoint = torch.load(args.restore_path)
|
checkpoint = torch.load(args.restore_path)
|
||||||
|
|
Loading…
Reference in New Issue