masked loss

Eren Golge 2018-03-22 14:35:02 -07:00
parent b10abadada
commit 96efe83a41
2 changed files with 20 additions and 19 deletions

layers/losses.py

@@ -1,5 +1,6 @@
 import torch
-from torch import functional
+from torch.nn import functional
+from torch.autograd import Variable

 # from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
@@ -7,7 +8,7 @@ def _sequence_mask(sequence_length, max_len=None):
     if max_len is None:
         max_len = sequence_length.data.max()
     batch_size = sequence_length.size(0)
-    seq_range = torch.range(0, max_len - 1).long()
+    seq_range = torch.arange(0, max_len).long()
     seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
     seq_range_expand = Variable(seq_range_expand)
     if sequence_length.is_cuda:
@@ -31,18 +32,20 @@ def L1LossMasked(input, target, length):
     Returns:
         loss: An average loss value masked by the length.
     """
+    input = input.contiguous()
+    target = target.contiguous()
-    # logits_flat: (batch * max_len, num_classes)
+    # logits_flat: (batch * max_len, dim)
     input = input.view(-1, input.size(-1))
-    # target_flat: (batch * max_len, 1)
+    # target_flat: (batch * max_len, dim)
     target_flat = target.view(-1, 1)
-    # losses_flat: (batch * max_len, 1)
-    losees_flat = functional.l1_loss(input, target, size_average=False,
+    # losses_flat: (batch * max_len, dim)
+    losses_flat = functional.l1_loss(input, target, size_average=False,
                                      reduce=False)
     # losses: (batch, max_len)
     losses = losses_flat.view(*target.size())
     # mask: (batch, max_len)
-    mask = _sequence_mask(sequence_length=length, max_len=target.size(1))
+    mask = _sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2)
     losses = losses * mask.float()
     loss = losses.sum() / length.float().sum()
-    return loss
+    return loss / input.shape[0]
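For readers following along outside the diff, here is a minimal standalone sketch of the masked L1 loss this change builds toward, written for current PyTorch, where torch.arange, boolean masks, and reduction="none" replace the deprecated torch.range, Variable, and size_average/reduce arguments used above. The names are illustrative, and the normalization here averages over valid elements rather than reproducing the diff's exact sum / lengths.sum() / input.shape[0] scaling.

import torch
from torch.nn import functional


def sequence_mask(lengths, max_len=None):
    # lengths: (batch,) int tensor of valid sequence lengths
    # returns a (batch, max_len) bool mask, True for valid time steps
    if max_len is None:
        max_len = int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) < lengths.unsqueeze(1)


def l1_loss_masked(output, target, lengths):
    # output, target: (batch, max_len, dim); lengths: (batch,)
    losses = functional.l1_loss(output, target, reduction="none")
    mask = sequence_mask(lengths, max_len=target.size(1)).unsqueeze(2).float()
    # zero out padded frames, then average over valid entries only
    return (losses * mask).sum() / (mask.sum() * output.size(-1))

Frames beyond each sequence's length then contribute nothing to the loss or its gradients, which is the point of replacing plain nn.L1Loss in this commit.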

train.py

@@ -26,7 +26,7 @@ from utils.model import get_param_size
 from utils.visual import plot_alignment, plot_spectrogram
 from datasets.LJSpeech import LJSpeechDataset
 from models.tacotron import Tacotron
-from losses import
+from layers.losses import L1LossMasked

 use_cuda = torch.cuda.is_available()
@@ -95,7 +95,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
         # convert inputs to variables
         text_input_var = Variable(text_input)
         mel_spec_var = Variable(mel_input)
-        mel_length_var = Variable(mel_lengths)
+        mel_lengths_var = Variable(mel_lengths)
         linear_spec_var = Variable(linear_input, volatile=True)

         # sort sequence by length for curriculum learning
@@ -105,6 +105,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
         sorted_lengths = sorted_lengths.long().numpy()
         text_input_var = text_input_var[indices]
         mel_spec_var = mel_spec_var[indices]
+        mel_lengths_var = mel_lengths_var[indices]
         linear_spec_var = linear_spec_var[indices]

         # dispatch data to GPU
@@ -119,11 +120,11 @@ def train(model, criterion, data_loader, optimizer, epoch):
             model.forward(text_input_var, mel_spec_var)

         # loss computation
-        mel_loss = criterion(mel_output, mel_spec_var, mel_lengths)
-        linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths) \
+        mel_loss = criterion(mel_output, mel_spec_var, mel_lengths_var)
+        linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
             + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                               linear_spec_var[: ,: ,:n_priority_freq],
-                              mel_lengths)
+                              mel_lengths_var)
         loss = mel_loss + linear_loss

         # backpass and check the grad norm
@@ -240,10 +241,10 @@ def evaluate(model, criterion, data_loader, current_step):
             # loss computation
             mel_loss = criterion(mel_output, mel_spec_var, mel_lengths)
-            linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths) \
+            linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
                 + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                                   linear_spec_var[: ,: ,:n_priority_freq],
-                                  mel_lengths)
+                                  mel_lengths_var)
             loss = mel_loss + linear_loss

             step_time = time.time() - start_time
@@ -348,10 +349,7 @@ def main(args):
     optimizer = optim.Adam(model.parameters(), lr=c.lr)

-    if use_cuda:
-        criterion = nn.L1Loss().cuda()
-    else:
-        criterion = nn.L1Loss()
+    criterion = L1LossMasked

     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
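A quick, hypothetical sanity check (not part of the commit) that the switch from nn.L1Loss to a masked criterion behaves as intended, using the l1_loss_masked sketch above: perturbing only the padded region of a sequence should leave the loss untouched.

import torch

batch, max_len, dim = 2, 6, 80
target = torch.rand(batch, max_len, dim)
output = torch.rand(batch, max_len, dim)
lengths = torch.tensor([6, 4])  # second sequence has 2 padded frames

base = l1_loss_masked(output, target, lengths)

# corrupt only the padded frames of the second sequence
corrupted = output.clone()
corrupted[1, 4:] += 100.0

assert torch.allclose(base, l1_loss_masked(corrupted, target, lengths))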