From 96efe83a41ac033019cb88a4158e074742ddcf23 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Thu, 22 Mar 2018 14:35:02 -0700
Subject: [PATCH] masked loss

---
 layers/losses.py | 19 +++++++++++--------
 train.py         | 20 +++++++++-----------
 2 files changed, 20 insertions(+), 19 deletions(-)

diff --git a/layers/losses.py b/layers/losses.py
index 0fdc654e..29ad7378 100644
--- a/layers/losses.py
+++ b/layers/losses.py
@@ -1,5 +1,6 @@
 import torch
-from torch import functional
+from torch.nn import functional
+from torch.autograd import Variable
 
 
 # from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
@@ -7,7 +8,7 @@ def _sequence_mask(sequence_length, max_len=None):
     if max_len is None:
         max_len = sequence_length.data.max()
     batch_size = sequence_length.size(0)
-    seq_range = torch.range(0, max_len - 1).long()
+    seq_range = torch.arange(0, max_len).long()
     seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
     seq_range_expand = Variable(seq_range_expand)
     if sequence_length.is_cuda:
@@ -31,18 +32,20 @@ def L1LossMasked(input, target, length):
         Returns:
             loss: An average loss value masked by the length.
     """
+    input = input.contiguous()
+    target = target.contiguous()
 
-    # logits_flat: (batch * max_len, num_classes)
+    # logits_flat: (batch * max_len, dim)
     input = input.view(-1, input.size(-1))
-    # target_flat: (batch * max_len, 1)
+    # target_flat: (batch * max_len, dim)
     target_flat = target.view(-1, 1)
-    # losses_flat: (batch * max_len, 1)
-    losees_flat = functional.l1_loss(input, target, size_average=False,
+    # losses_flat: (batch * max_len, dim)
+    losses_flat = functional.l1_loss(input, target, size_average=False,
                                      reduce=False)
     # losses: (batch, max_len)
     losses = losses_flat.view(*target.size())
     # mask: (batch, max_len)
-    mask = _sequence_mask(sequence_length=length, max_len=target.size(1))
+    mask = _sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2)
     losses = losses * mask.float()
     loss = losses.sum() / length.float().sum()
-    return loss
\ No newline at end of file
+    return loss / input.shape[0]
\ No newline at end of file
diff --git a/train.py b/train.py
index b39b17d9..c4d34e2d 100644
--- a/train.py
+++ b/train.py
@@ -26,7 +26,7 @@
 from utils.model import get_param_size
 from utils.visual import plot_alignment, plot_spectrogram
 from datasets.LJSpeech import LJSpeechDataset
 from models.tacotron import Tacotron
-from losses import
+from layers.losses import L1LossMasked
 
 use_cuda = torch.cuda.is_available()
@@ -95,7 +95,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
         # convert inputs to variables
         text_input_var = Variable(text_input)
         mel_spec_var = Variable(mel_input)
-        mel_length_var = Variable(mel_lengths)
+        mel_lengths_var = Variable(mel_lengths)
         linear_spec_var = Variable(linear_input, volatile=True)
 
         # sort sequence by length for curriculum learning
@@ -105,6 +105,7 @@
         sorted_lengths = sorted_lengths.long().numpy()
         text_input_var = text_input_var[indices]
         mel_spec_var = mel_spec_var[indices]
+        mel_lengths_var = mel_lengths_var[indices]
         linear_spec_var = linear_spec_var[indices]
 
         # dispatch data to GPU
@@ -119,11 +120,11 @@ def train(model, criterion, data_loader, optimizer, epoch):
             model.forward(text_input_var, mel_spec_var)
 
         # loss computation
-        mel_loss = criterion(mel_output, mel_spec_var, mel_lengths)
-        linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths) \
+        mel_loss = criterion(mel_output, mel_spec_var, mel_lengths_var)
+        linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
                 + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                                   linear_spec_var[: ,: ,:n_priority_freq],
-                                  mel_lengths)
+                                  mel_lengths_var)
         loss = mel_loss + linear_loss
 
         # backpass and check the grad norm
@@ -240,10 +241,10 @@ def evaluate(model, criterion, data_loader, current_step):
 
             # loss computation
             mel_loss = criterion(mel_output, mel_spec_var, mel_lengths)
-            linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths) \
+            linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
                 + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                                   linear_spec_var[: ,: ,:n_priority_freq],
-                                  mel_lengths)
+                                  mel_lengths_var)
             loss = mel_loss + linear_loss
 
             step_time = time.time() - start_time
@@ -348,10 +349,7 @@ def main(args):
 
     optimizer = optim.Adam(model.parameters(), lr=c.lr)
 
-    if use_cuda:
-        criterion = nn.L1Loss().cuda()
-    else:
-        criterion = nn.L1Loss()
+    criterion = L1LossMasked
 
     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
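
Reviewer note, not part of the patch: below is a minimal, self-contained sketch of the masking scheme this commit introduces, restated against a current PyTorch API (boolean masks, reduction="none") instead of the 0.3-era Variable / reduce=False API used above. The helper names sequence_mask and masked_l1 are illustrative only, and unlike L1LossMasked this sketch averages over the number of unmasked elements rather than dividing by length.sum() and then the batch size.

    import torch
    import torch.nn.functional as F

    def sequence_mask(lengths, max_len=None):
        # lengths: (batch,) int tensor of valid frame counts
        # returns: (batch, max_len) bool mask, True on valid frames
        if max_len is None:
            max_len = int(lengths.max())
        seq_range = torch.arange(max_len, device=lengths.device)
        # broadcast (1, max_len) < (batch, 1) -> (batch, max_len)
        return seq_range.unsqueeze(0) < lengths.unsqueeze(1)

    def masked_l1(output, target, lengths):
        # output, target: (batch, max_len, dim); lengths: (batch,)
        losses = F.l1_loss(output, target, reduction="none")
        mask = sequence_mask(lengths, target.size(1)).unsqueeze(2).float()
        # zero the padded frames, then average over valid elements only
        return (losses * mask).sum() / (mask.sum() * target.size(2))

    # quick check: the padded tail of the second sequence is ignored
    out = torch.randn(2, 5, 3)
    tgt = torch.randn(2, 5, 3)
    print(masked_l1(out, tgt, torch.tensor([5, 3])))

Broadcasting arange(max_len) against the length column yields the same mask as _sequence_mask without materializing the expanded range matrix via expand.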