From 19c37b28490a9456ab02d308aa29b8690ccc9376 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Thu, 22 Mar 2018 13:46:52 -0700
Subject: [PATCH] masked loss

---
 datasets/LJSpeech.py  | 20 ++++++++++++--------
 tests/layers_tests.py |  6 +-----
 tests/loader_tests.py | 16 +++++++++++-----
 train.py              | 21 +++++++++++++++------
 4 files changed, 39 insertions(+), 24 deletions(-)

diff --git a/datasets/LJSpeech.py b/datasets/LJSpeech.py
index fb6c9304..7b50e646 100644
--- a/datasets/LJSpeech.py
+++ b/datasets/LJSpeech.py
@@ -98,9 +98,6 @@ class LJSpeechDataset(Dataset):
             mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
             mel_lengths = [m.shape[1] for m in mel]

-            # compute 'stop token' targets
-            stop_targets = [np.array([0.]*mel_len) for mel_len in mel_lengths]
-
             # PAD sequences with largest length of the batch
             text = prepare_data(text).astype(np.int32)
             wav = prepare_data(wav)
@@ -111,9 +108,6 @@ class LJSpeechDataset(Dataset):
             assert mel.shape[2] == linear.shape[2]
             timesteps = mel.shape[2]

-            # PAD stop targets
-            stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)
-
             # PAD with zeros that can be divided by outputs per step
             if (timesteps + 1) % self.outputs_per_step != 0:
                 pad_len = self.outputs_per_step - \
@@ -123,8 +117,17 @@ class LJSpeechDataset(Dataset):
                 pad_len = 1
             linear = pad_per_step(linear, pad_len)
             mel = pad_per_step(mel, pad_len)
+
+            # update mel lengths
+            mel_lengths = [l+pad_len for l in mel_lengths]
+
+            # compute 'stop token' targets
+            stop_targets = [np.array([0.]*mel_len) for mel_len in mel_lengths]
+
+            # PAD stop targets
+            stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)

-            # reshape mojo
+            # B x T x D
             linear = linear.transpose(0, 2, 1)
             mel = mel.transpose(0, 2, 1)

@@ -133,8 +136,9 @@ class LJSpeechDataset(Dataset):
             text = torch.LongTensor(text)
             linear = torch.FloatTensor(linear)
             mel = torch.FloatTensor(mel)
+            mel_lengths = torch.LongTensor(mel_lengths)
             stop_targets = torch.FloatTensor(stop_targets)
-            return text, text_lenghts, linear, mel, stop_targets, item_idxs[0]
+            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]

         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
                          found {}"
diff --git a/tests/layers_tests.py b/tests/layers_tests.py
index e8ebba0d..14739bf9 100644
--- a/tests/layers_tests.py
+++ b/tests/layers_tests.py
@@ -37,16 +37,12 @@ class DecoderTests(unittest.TestCase):
         dummy_memory = T.autograd.Variable(T.rand(4, 120, 32))

         print(layer)
-        output, alignment, stop_output = layer(dummy_input, dummy_memory)
+        output, alignment = layer(dummy_input, dummy_memory)

         print(output.shape)
-        print(" > Stop ", stop_output.shape)
         assert output.shape[0] == 4
         assert output.shape[1] == 120 / 5
         assert output.shape[2] == 32 * 5
-        assert stop_output.shape[0] == 4
-        assert stop_output.shape[1] == 120 / 5
-        assert stop_output.shape[2] == 5


 class EncoderTests(unittest.TestCase):
diff --git a/tests/loader_tests.py b/tests/loader_tests.py
index dc023b60..3b3d017c 100644
--- a/tests/loader_tests.py
+++ b/tests/loader_tests.py
@@ -43,9 +43,10 @@ class TestDataset(unittest.TestCase):
             text_lengths = data[1]
             linear_input = data[2]
             mel_input = data[3]
-            stop_targets = data[4]
-            item_idx = data[5]
-
+            mel_lengths = data[4]
+            stop_target = data[5]
+            item_idx = data[6]
+
             neg_values = text_input[text_input < 0]
             check_count = len(neg_values)
             assert check_count == 0, \
@@ -82,8 +83,9 @@ class TestDataset(unittest.TestCase):
             text_lengths = data[1]
             linear_input = data[2]
             mel_input = data[3]
-            stop_target = data[4]
-            item_idx = data[5]
+            mel_lengths = data[4]
+            stop_target = data[5]
+            item_idx = data[6]

             # check the last time step to be zero padded
             assert mel_input[0, -1].sum() == 0
@@ -92,6 +94,10 @@ class TestDataset(unittest.TestCase):
             assert linear_input[0, -2].sum() != 0
             assert stop_target[0, -1] == 1
             assert stop_target.sum() == 1
+            assert len(mel_lengths.shape) == 1
+            print(mel_lengths)
+            print(mel_input)
+            assert mel_lengths[0] == mel_input[0].shape[0]



diff --git a/train.py b/train.py
index 3e1a75c7..b39b17d9 100644
--- a/train.py
+++ b/train.py
@@ -26,6 +26,7 @@ from utils.model import get_param_size
 from utils.visual import plot_alignment, plot_spectrogram
 from datasets.LJSpeech import LJSpeechDataset
 from models.tacotron import Tacotron
+from losses import

 use_cuda = torch.cuda.is_available()

@@ -80,6 +81,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
         text_lengths = data[1]
         linear_input = data[2]
         mel_input = data[3]
+        mel_lengths = data[4]

         current_step = num_iter + args.restore_step + epoch * len(data_loader) + 1

@@ -93,6 +95,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
         # convert inputs to variables
         text_input_var = Variable(text_input)
         mel_spec_var = Variable(mel_input)
+        mel_lengths_var = Variable(mel_lengths)
         linear_spec_var = Variable(linear_input, volatile=True)

         # sort sequence by length for curriculum learning
@@ -108,6 +111,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
         if use_cuda:
             text_input_var = text_input_var.cuda()
             mel_spec_var = mel_spec_var.cuda()
+            mel_lengths_var = mel_lengths_var.cuda()
             linear_spec_var = linear_spec_var.cuda()

         # forward pass
@@ -115,10 +119,11 @@ def train(model, criterion, data_loader, optimizer, epoch):
             model.forward(text_input_var, mel_spec_var)

         # loss computation
-        mel_loss = criterion(mel_output, mel_spec_var)
-        linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \
+        mel_loss = criterion(mel_output, mel_spec_var, mel_lengths_var)
+        linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
             + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
-                              linear_spec_var[: ,: ,:n_priority_freq])
+                              linear_spec_var[: ,: ,:n_priority_freq],
+                              mel_lengths_var)
         loss = mel_loss + linear_loss

         # backpass and check the grad norm
@@ -215,26 +220,30 @@ def evaluate(model, criterion, data_loader, current_step):
         text_lengths = data[1]
         linear_input = data[2]
         mel_input = data[3]
+        mel_lengths = data[4]

         # convert inputs to variables
         text_input_var = Variable(text_input)
         mel_spec_var = Variable(mel_input)
+        mel_lengths_var = Variable(mel_lengths)
         linear_spec_var = Variable(linear_input, volatile=True)

         # dispatch data to GPU
         if use_cuda:
             text_input_var = text_input_var.cuda()
             mel_spec_var = mel_spec_var.cuda()
+            mel_lengths_var = mel_lengths_var.cuda()
             linear_spec_var = linear_spec_var.cuda()

         # forward pass
         mel_output, linear_output, alignments = model.forward(text_input_var, mel_spec_var)

         # loss computation
-        mel_loss = criterion(mel_output, mel_spec_var)
-        linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \
+        mel_loss = criterion(mel_output, mel_spec_var, mel_lengths_var)
+        linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
             + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
-                              linear_spec_var[: ,: ,:n_priority_freq])
+                              linear_spec_var[: ,: ,:n_priority_freq],
+                              mel_lengths_var)
         loss = mel_loss + linear_loss

         step_time = time.time() - start_time
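
Note on the masked criterion: the losses module pulled in by the new import in train.py is not part of this diff, so the implementation behind the three-argument calls criterion(output, target, lengths) is not visible here. Purely as a minimal sketch, assuming a length-masked L1 loss and written against current PyTorch rather than the Variable-era API used in the patch (the names L1LossMasked and sequence_mask are illustrative, not taken from this commit):

    import torch
    from torch import nn
    from torch.nn import functional as F


    def sequence_mask(lengths, max_len=None):
        # lengths: (B,) tensor of valid frame counts -> (B, T) boolean mask
        if max_len is None:
            max_len = int(lengths.max())
        steps = torch.arange(max_len, device=lengths.device)
        return steps.unsqueeze(0) < lengths.unsqueeze(1)


    class L1LossMasked(nn.Module):
        # Hypothetical masked criterion: L1 averaged over unpadded frames only.
        def forward(self, inputs, targets, lengths):
            # inputs, targets: (B, T, D); lengths: (B,)
            mask = sequence_mask(lengths, max_len=targets.size(1))
            mask = mask.unsqueeze(2).to(inputs.dtype)  # (B, T, 1)
            loss = F.l1_loss(inputs * mask, targets * mask, reduction='sum')
            # normalize by the number of real elements, not B * T * D
            return loss / (mask.sum() * targets.size(2))

With a criterion along these lines, the zero frames appended by pad_per_step contribute nothing to mel_loss or linear_loss, which is the point of threading mel_lengths through collate_fn and into the loss calls.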