diff --git a/config.json b/config.json
index ffea4466..ed9764b1 100644
--- a/config.json
+++ b/config.json
@@ -7,25 +7,24 @@
     "preemphasis": 0.97,
     "min_level_db": -100,
     "ref_level_db": 20,
-    "hidden_size": 128,
     "embedding_size": 256,
     "text_cleaner": "english_cleaners",
 
     "epochs": 2000,
     "lr": 0.001,
     "warmup_steps": 4000,
-    "batch_size": 32,
-    "eval_batch_size": 32,
+    "batch_size": 128,
+    "eval_batch_size": 32,
     "r": 5,
 
     "griffin_lim_iters": 60,
     "power": 1.5,
 
-    "num_loader_workers": 12,
+    "num_loader_workers": 8,
 
-    "checkpoint": false,
-    "save_step": 69,
+    "checkpoint": true,
+    "save_step": 378,
     "data_path": "/run/shm/erogol/LJSpeech-1.0",
     "min_seq_len": 0,
-    "output_path": "result"
+    "output_path": "/data/shared/erogol_models/"
 }
diff --git a/datasets/LJSpeech.py b/datasets/LJSpeech.py
index 5334e1ca..a773c661 100644
--- a/datasets/LJSpeech.py
+++ b/datasets/LJSpeech.py
@@ -7,7 +7,8 @@ from torch.utils.data import Dataset
 from TTS.utils.text import text_to_sequence
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.data import prepare_data, pad_data, pad_per_step
+from TTS.utils.data import (prepare_data, pad_per_step,
+                            prepare_tensor, prepare_stop_target)
 
 
 class LJSpeechDataset(Dataset):
@@ -93,26 +94,27 @@ class LJSpeechDataset(Dataset):
             text_lenghts = np.array([len(x) for x in text])
             max_text_len = np.max(text_lenghts)
 
+            linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
+            mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
+            mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame
+
+            # compute 'stop token' targets
+            stop_targets = [np.array([0.]*(mel_len-1)) for mel_len in mel_lengths]
+
+            # PAD stop targets
+            stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)
+
             # PAD sequences with largest length of the batch
             text = prepare_data(text).astype(np.int32)
             wav = prepare_data(wav)
 
-            linear = np.array([self.ap.spectrogram(w).astype('float32') for w in wav])
-            mel = np.array([self.ap.melspectrogram(w).astype('float32') for w in wav])
+            # PAD features with largest length + a zero frame
+            linear = prepare_tensor(linear, self.outputs_per_step)
+            mel = prepare_tensor(mel, self.outputs_per_step)
             assert mel.shape[2] == linear.shape[2]
-            timesteps = mel.shape[2]
+            timesteps = mel.shape[2]
 
-            # PAD with zeros that can be divided by outputs per step
-            if (timesteps + 1) % self.outputs_per_step != 0:
-                pad_len = self.outputs_per_step - \
-                    ((timesteps + 1) % self.outputs_per_step)
-                pad_len += 1
-            else:
-                pad_len = 1
-            linear = pad_per_step(linear, pad_len)
-            mel = pad_per_step(mel, pad_len)
-
-            # reshape jombo
+            # B x T x D
             linear = linear.transpose(0, 2, 1)
             mel = mel.transpose(0, 2, 1)
 
@@ -121,7 +123,10 @@ class LJSpeechDataset(Dataset):
             text = torch.LongTensor(text)
             linear = torch.FloatTensor(linear)
             mel = torch.FloatTensor(mel)
-            return text, text_lenghts, linear, mel, item_idxs[0]
+            mel_lengths = torch.LongTensor(mel_lengths)
+            stop_targets = torch.FloatTensor(stop_targets)
+
+            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]
 
         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
                          found {}"
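The reworked collate_fn above emits one "stop token" target per mel frame: 0 while the decoder should keep generating, and 1 from the appended zero-frame onwards, padding included. A minimal standalone illustration of that convention (NumPy only, made-up frame counts):

import numpy as np

mel_len = 6 + 1                               # 6 real frames + 1 zero-frame, as in collate_fn
stop_target = np.array([0.] * (mel_len - 1))  # zeros for the 6 real frames
padded = np.pad(stop_target, (0, 3), mode='constant', constant_values=1.)
print(padded)                                 # [0. 0. 0. 0. 0. 0. 1. 1. 1.]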
diff --git a/layers/.tacotron.py.swo b/layers/.tacotron.py.swo
deleted file mode 100644
index c637f447..00000000
Binary files a/layers/.tacotron.py.swo and /dev/null differ
diff --git a/layers/attention.py b/layers/attention.py
index e7385149..1626e949 100644
--- a/layers/attention.py
+++ b/layers/attention.py
@@ -48,7 +48,7 @@ class AttentionRNN(nn.Module):
     def __init__(self, out_dim, annot_dim, memory_dim,
                  score_mask_value=-float("inf")):
         super(AttentionRNN, self).__init__()
-        self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, out_dim)
+        self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
         self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
         self.score_mask_value = score_mask_value
 
@@ -57,11 +57,19 @@ class AttentionRNN(nn.Module):
         if annotations_lengths is not None and mask is None:
             mask = get_mask_from_lengths(annotations, annotations_lengths)
+
+        # Concat input query and previous context
+        rnn_input = torch.cat((memory, context), -1)
+        #rnn_input = rnn_input.unsqueeze(1)
+
+        # Feed it to RNN
+        # s_i = f(y_{i-1}, c_{i}, s_{i-1})
+        rnn_output = self.rnn_cell(rnn_input, rnn_state)
 
         # Alignment
         # (batch, max_time)
         # e_{ij} = a(s_{i-1}, h_j)
-        alignment = self.alignment_model(annotations, rnn_state)
+        alignment = self.alignment_model(annotations, rnn_output)
 
         # TODO: needs recheck.
         if mask is not None:
@@ -75,16 +83,6 @@ class AttentionRNN(nn.Module):
         # (batch, 1, dim)
         # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
         context = torch.bmm(alignment.unsqueeze(1), annotations)
-        context = context.squeeze(1)
-
-        # Concat input query and previous context context
-        rnn_input = torch.cat((memory, context), -1)
-        #rnn_input = rnn_input.unsqueeze(1)
-
-        # Feed it to RNN
-        # s_i = f(y_{i-1}, c_{i}, s_{i-1})
-        rnn_output = self.rnn_cell(rnn_input, rnn_state)
-
+        context = context.squeeze(1)
 
         return rnn_output, context, alignment
diff --git a/layers/custom_layers.py b/layers/custom_layers.py
new file mode 100644
index 00000000..d659efb2
--- /dev/null
+++ b/layers/custom_layers.py
@@ -0,0 +1,26 @@
+# coding: utf-8
+import torch
+from torch.autograd import Variable
+from torch import nn
+
+
+# class StopProjection(nn.Module):
+#     r""" Simple projection layer to predict the "stop token"
+
+#     Args:
+#         in_features (int): size of the input vector
+#         out_features (int or list): size of each output vector. aka number
+#             of predicted frames.
+#     """
+
+#     def __init__(self, in_features, out_features):
+#         super(StopProjection, self).__init__()
+#         self.linear = nn.Linear(in_features, out_features)
+#         self.dropout = nn.Dropout(0.5)
+#         self.sigmoid = nn.Sigmoid()
+
+#     def forward(self, inputs):
+#         out = self.dropout(inputs)
+#         out = self.linear(out)
+#         out = self.sigmoid(out)
+#         return out
\ No newline at end of file
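In the reordered attention step above, the recurrent state is updated first and the alignment is then scored against the updated state instead of the previous one. A minimal sketch of that ordering, assuming annot_dim == out_dim (which holds for the 256-dim decoder here) and with the Bahdanau scorer swapped for a plain dot product purely to keep the example self-contained:

import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F

B, T_enc, annot_dim, memory_dim = 2, 7, 256, 128
out_dim = annot_dim                                        # assumption: true for this decoder config
rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)

annotations = Variable(torch.rand(B, T_enc, annot_dim))    # encoder outputs h_j
memory = Variable(torch.rand(B, memory_dim))               # processed decoder input y_{i-1}
context = Variable(torch.zeros(B, annot_dim))              # previous context c_{i-1}
rnn_state = Variable(torch.zeros(B, out_dim))              # previous state s_{i-1}

rnn_input = torch.cat((memory, context), -1)               # [y_{i-1}; c_{i-1}]
rnn_output = rnn_cell(rnn_input, rnn_state)                # s_i
scores = torch.bmm(annotations, rnn_output.unsqueeze(2)).squeeze(2)   # stand-in for a(s_i, h_j)
alignment = F.softmax(scores, dim=-1)
context = torch.bmm(alignment.unsqueeze(1), annotations).squeeze(1)   # c_i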
diff --git a/layers/losses.py b/layers/losses.py
new file mode 100644
index 00000000..67bc0f22
--- /dev/null
+++ b/layers/losses.py
@@ -0,0 +1,57 @@
+import torch
+from torch.nn import functional
+from torch.autograd import Variable
+from torch import nn
+
+
+# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
+def _sequence_mask(sequence_length, max_len=None):
+    if max_len is None:
+        max_len = sequence_length.data.max()
+    batch_size = sequence_length.size(0)
+    seq_range = torch.arange(0, max_len).long()
+    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
+    seq_range_expand = Variable(seq_range_expand)
+    if sequence_length.is_cuda:
+        seq_range_expand = seq_range_expand.cuda()
+    seq_length_expand = (sequence_length.unsqueeze(1)
+                         .expand_as(seq_range_expand))
+    return seq_range_expand < seq_length_expand
+
+
+class L1LossMasked(nn.Module):
+
+    def __init__(self):
+        super(L1LossMasked, self).__init__()
+
+    def forward(self, input, target, length):
+        """
+        Args:
+            input: A Variable containing a FloatTensor of size
+                (batch, max_len, dim) which contains the
+                predicted values for each step.
+            target: A Variable containing a FloatTensor of size
+                (batch, max_len, dim) which contains the target
+                values for each corresponding step.
+            length: A Variable containing a LongTensor of size (batch,)
+                which contains the length of each data in a batch.
+        Returns:
+            loss: An average loss value masked by the length.
+        """
+        input = input.contiguous()
+        target = target.contiguous()
+
+        # input_flat: (batch * max_len, dim)
+        input = input.view(-1, input.size(-1))
+        # target_flat: (batch * max_len, dim)
+        target_flat = target.view(-1, target.size(-1))
+        # losses_flat: (batch * max_len, dim)
+        losses_flat = functional.l1_loss(input, target_flat, size_average=False,
+                                         reduce=False)
+        # losses: (batch, max_len, dim)
+        losses = losses_flat.view(*target.size())
+        # mask: (batch, max_len, 1)
+        mask = _sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2)
+        losses = losses * mask.float()
+        loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
+        return loss
\ No newline at end of file
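A short usage sketch for the masked loss defined above: only the first `length` frames of each item contribute, so zero-padded frames cannot dilute the average. Shapes follow the docstring; the values are made up, and the import path assumes the script is run from the repository root, as in the tests further down.

import torch
from torch.autograd import Variable
from layers.losses import L1LossMasked

criterion = L1LossMasked()
pred = Variable(torch.ones(4, 8, 80))               # (batch, max_len, dim)
target = Variable(torch.zeros(4, 8, 80))
lengths = Variable(torch.LongTensor([8, 6, 5, 3]))

loss = criterion(pred, target, lengths)
print(loss.data[0])    # 1.0: |1 - 0| averaged over the 22 valid frames only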
diff --git a/layers/tacotron.py b/layers/tacotron.py
index 38471214..916ea677 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -48,6 +48,7 @@ class BatchNormConv1d(nn.Module):
         - input: batch x dims
         - output: batch x dims
     """
+
     def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                  activation=None):
         super(BatchNormConv1d, self).__init__()
@@ -213,8 +214,9 @@ class Decoder(nn.Module):
         r (int): number of outputs per time step.
         eps (float): threshold for detecting the end of a sentence.
     """
-    def __init__(self, in_features, memory_dim, r, eps=0.05):
+    def __init__(self, in_features, memory_dim, r, eps=0.05, mode='train'):
         super(Decoder, self).__init__()
+        self.mode = mode
         self.max_decoder_steps = 200
         self.memory_dim = memory_dim
         self.eps = eps
@@ -241,7 +243,8 @@ class Decoder(nn.Module):
         Args:
             inputs: Encoder outputs.
             memory (None): Decoder memory (autoregression. If None (at eval-time),
-                decoder outputs are used as decoder inputs.
+                decoder outputs are used as decoder inputs. If None, it uses the last
+                output as the input.
 
         Shapes:
             - inputs: batch x time x encoder_out_dim
@@ -250,14 +253,13 @@ class Decoder(nn.Module):
         B = inputs.size(0)
 
         # Run greedy decoding if memory is None
-        greedy = memory is None
+        greedy = not self.training
 
         if memory is not None:
-
+            # Grouping multiple frames if necessary
             if memory.size(-1) == self.memory_dim:
                 memory = memory.view(B, memory.size(1) // self.r, -1)
-
             assert memory.size(-1) == self.memory_dim * self.r,\
                 " !! Dimension mismatch {} vs {} * {}".format(memory.size(-1),
                                                               self.memory_dim, self.r)
             T_decoder = memory.size(1)
@@ -286,15 +288,23 @@ class Decoder(nn.Module):
         memory_input = initial_memory
         while True:
             if t > 0:
-                memory_input = outputs[-1] if greedy else memory[t - 1]
+                if greedy:
+                    memory_input = outputs[-1]
+                else:
+                    # combine prev. model output and prev. real target
+                    # memory_input = torch.div(outputs[-1] + memory[t-1], 2.0)
+                    # add a random noise
+                    # noise = torch.autograd.Variable(
+                    #     memory_input.data.new(memory_input.size()).normal_(0.0, 0.5))
+                    # memory_input = memory_input + noise
+                    memory_input = memory[t-1]
 
             # Prenet
             processed_memory = self.prenet(memory_input)
 
             # Attention RNN
             attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn(
-                processed_memory, current_context_vec, attention_rnn_hidden,
-                inputs)
+                processed_memory, current_context_vec, attention_rnn_hidden, inputs)
 
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(
@@ -306,8 +316,9 @@ class Decoder(nn.Module):
                     decoder_input, decoder_rnn_hiddens[idx])
                 # Residual connectinon
                 decoder_input = decoder_rnn_hiddens[idx] + decoder_input
-
+
             output = decoder_input
+            # predict mel vectors from decoder vectors
             output = self.proj_to_mel(output)
 
@@ -317,17 +328,17 @@ class Decoder(nn.Module):
 
             t += 1
 
-            if greedy:
+            if (not greedy and self.training) or (greedy and memory is not None):
+                if t >= T_decoder:
+                    break
+            else:
                 if t > 1 and is_end_of_frames(output, self.eps):
                     break
                 elif t > self.max_decoder_steps:
                     print(" !! Decoder stopped with 'max_decoder_steps'. \
                           Something is probably wrong.")
                     break
-            else:
-                if t >= T_decoder:
-                    break
-
+
         assert greedy or len(outputs) == T_decoder
 
         # Back to batch first
@@ -338,4 +349,4 @@ class Decoder(nn.Module):
 
 
 def is_end_of_frames(output, eps=0.2): #0.2
-    return (output.data <= eps).all()
+    return (output.data <= eps).all()
\ No newline at end of file
diff --git a/models/.tacotron.py.swo b/models/.tacotron.py.swo
deleted file mode 100644
index b4cfd7c5..00000000
Binary files a/models/.tacotron.py.swo and /dev/null differ
diff --git a/models/tacotron.py b/models/tacotron.py
index 7653f1c3..05bb1292 100644
--- a/models/tacotron.py
+++ b/models/tacotron.py
@@ -8,9 +8,10 @@ from TTS.layers.tacotron import Prenet, Encoder, Decoder, CBHG
 
 class Tacotron(nn.Module):
     def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80,
-                 freq_dim=1025, r=5, padding_idx=None):
+                 r=5, padding_idx=None):
         super(Tacotron, self).__init__()
+        self.r = r
         self.mel_dim = mel_dim
         self.linear_dim = linear_dim
         self.embedding = nn.Embedding(len(symbols), embedding_dim,
@@ -23,9 +24,10 @@ class Tacotron(nn.Module):
         self.decoder = Decoder(256, mel_dim, r)
         self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
-        self.last_linear = nn.Linear(mel_dim * 2, freq_dim)
+        self.last_linear = nn.Linear(mel_dim * 2, linear_dim)
 
     def forward(self, characters, mel_specs=None):
+
         B = characters.size(0)
 
         inputs = self.embedding(characters)
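The decoder above consumes r mel frames per step, so the teacher-forcing memory is folded from (B, T, mel_dim) into (B, T // r, mel_dim * r); this is why the collate code now pads the frame axis to a multiple of r. A small shape-only sketch:

import torch

B, T, mel_dim, r = 2, 30, 80, 5
memory = torch.rand(B, T, mel_dim)
grouped = memory.view(B, T // r, mel_dim * r)
print(grouped.shape)    # torch.Size([2, 6, 400]) -> the decoder runs 6 steps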
diff --git a/tests/layers_tests.py b/tests/layers_tests.py
index 3fbab022..570b474c 100644
--- a/tests/layers_tests.py
+++ b/tests/layers_tests.py
@@ -2,6 +2,7 @@ import unittest
 import torch as T
 
 from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder
+from layers.losses import L1LossMasked, _sequence_mask
 
 
 class PrenetTests(unittest.TestCase):
@@ -32,23 +33,22 @@ class CBHGTests(unittest.TestCase):
 class DecoderTests(unittest.TestCase):
 
     def test_in_out(self):
-        layer = Decoder(in_features=128, memory_dim=32, r=5)
-        dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
-        dummy_memory = T.autograd.Variable(T.rand(4, 120, 32))
+        layer = Decoder(in_features=256, memory_dim=80, r=2)
+        dummy_input = T.autograd.Variable(T.rand(4, 8, 256))
+        dummy_memory = T.autograd.Variable(T.rand(4, 2, 80))
 
-        print(layer)
         output, alignment = layer(dummy_input, dummy_memory)
-        print(output.shape)
+
         assert output.shape[0] == 4
-        assert output.shape[1] == 120 / 5
-        assert output.shape[2] == 32 * 5
-
+        assert output.shape[1] == 1, "size not {}".format(output.shape[1])
+        assert output.shape[2] == 80 * 2, "size not {}".format(output.shape[2])
+
 
 class EncoderTests(unittest.TestCase):
 
     def test_in_out(self):
         layer = Encoder(128)
-        dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
+        dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
 
         print(layer)
         output = layer(dummy_input)
@@ -56,4 +56,29 @@ class EncoderTests(unittest.TestCase):
         assert output.shape[0] == 4
         assert output.shape[1] == 8
         assert output.shape[2] == 256  # 128 * 2 BiRNN
+
+class L1LossMaskedTests(unittest.TestCase):
+
+    def test_in_out(self):
+        layer = L1LossMasked()
+        dummy_input = T.autograd.Variable(T.ones(4, 8, 128).float())
+        dummy_target = T.autograd.Variable(T.ones(4, 8, 128).float())
+        dummy_length = T.autograd.Variable((T.ones(4) * 8).long())
+        output = layer(dummy_input, dummy_target, dummy_length)
+        assert output.shape[0] == 1
+        assert len(output.shape) == 1
+        assert output.data[0] == 0.0
+
+        dummy_input = T.autograd.Variable(T.ones(4, 8, 128).float())
+        dummy_target = T.autograd.Variable(T.zeros(4, 8, 128).float())
+        dummy_length = T.autograd.Variable((T.ones(4) * 8).long())
+        output = layer(dummy_input, dummy_target, dummy_length)
+        assert output.data[0] == 1.0, "1.0 vs {}".format(output.data[0])
+
+        dummy_input = T.autograd.Variable(T.ones(4, 8, 128).float())
+        dummy_target = T.autograd.Variable(T.zeros(4, 8, 128).float())
+        dummy_length = T.autograd.Variable((T.arange(5, 9)).long())
+        mask = ((_sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        output = layer(dummy_input + mask, dummy_target, dummy_length)
+        assert output.data[0] == 1.0, "1.0 vs {}".format(output.data[0])
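The last L1LossMasked test above leans on _sequence_mask: every position past an item's length gets a large negative offset (-100), so the loss could not stay at exactly 1.0 unless those positions were really excluded by the mask. A quick look at the offset it builds:

import torch
from torch.autograd import Variable
from layers.losses import _sequence_mask

lengths = Variable(torch.arange(5, 9).long())          # [5, 6, 7, 8]
offset = (_sequence_mask(lengths).float() - 1.0) * 100.0
print(offset[0])    # 0 for the first 5 steps, -100 for the 3 padded steps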
diff --git a/tests/loader_tests.py b/tests/loader_tests.py
index fdecd6eb..769fbebe 100644
--- a/tests/loader_tests.py
+++ b/tests/loader_tests.py
@@ -32,7 +32,7 @@ class TestDataset(unittest.TestCase):
                                   c.power
                                   )
 
-        dataloader = DataLoader(dataset, batch_size=c.batch_size,
+        dataloader = DataLoader(dataset, batch_size=2,
                                 shuffle=True, collate_fn=dataset.collate_fn,
                                 drop_last=True, num_workers=c.num_loader_workers)
 
@@ -43,8 +43,10 @@ class TestDataset(unittest.TestCase):
             text_lengths = data[1]
             linear_input = data[2]
             mel_input = data[3]
-            item_idx = data[4]
-
+            mel_lengths = data[4]
+            stop_target = data[5]
+            item_idx = data[6]
+
             neg_values = text_input[text_input < 0]
             check_count = len(neg_values)
             assert check_count == 0, \
@@ -70,8 +72,9 @@ class TestDataset(unittest.TestCase):
                                   c.power
                                   )
 
+        # Test for batch size 1
         dataloader = DataLoader(dataset, batch_size=1,
-                                shuffle=True, collate_fn=dataset.collate_fn,
+                                shuffle=False, collate_fn=dataset.collate_fn,
                                 drop_last=True, num_workers=c.num_loader_workers)
 
         for i, data in enumerate(dataloader):
@@ -81,13 +84,63 @@ class TestDataset(unittest.TestCase):
             text_lengths = data[1]
             linear_input = data[2]
             mel_input = data[3]
-            item_idx = data[4]
+            mel_lengths = data[4]
+            stop_target = data[5]
+            item_idx = data[6]
 
             # check the last time step to be zero padded
             assert mel_input[0, -1].sum() == 0
             assert mel_input[0, -2].sum() != 0
             assert linear_input[0, -1].sum() == 0
             assert linear_input[0, -2].sum() != 0
+            assert stop_target[0, -1] == 1
+            assert stop_target[0, -2] == 0
+            assert stop_target.sum() == 1
+            assert len(mel_lengths.shape) == 1
+            assert mel_lengths[0] == mel_input[0].shape[0]
+
+        # Test for batch size 2
+        dataloader = DataLoader(dataset, batch_size=2,
+                                shuffle=False, collate_fn=dataset.collate_fn,
+                                drop_last=False, num_workers=c.num_loader_workers)
+
+        for i, data in enumerate(dataloader):
+            if i == self.max_loader_iter:
+                break
+            text_input = data[0]
+            text_lengths = data[1]
+            linear_input = data[2]
+            mel_input = data[3]
+            mel_lengths = data[4]
+            stop_target = data[5]
+            item_idx = data[6]
+
+            if mel_lengths[0] > mel_lengths[1]:
+                idx = 0
+            else:
+                idx = 1
+
+            # check the first item in the batch
+            assert mel_input[idx, -1].sum() == 0
+            assert mel_input[idx, -2].sum() != 0, mel_input
+            assert linear_input[idx, -1].sum() == 0
+            assert linear_input[idx, -2].sum() != 0
+            assert stop_target[idx, -1] == 1
+            assert stop_target[idx, -2] == 0
+            assert stop_target[idx].sum() == 1
+            assert len(mel_lengths.shape) == 1
+            assert mel_lengths[idx] == mel_input[idx].shape[0]
+
+            # check the second item in the batch
+            assert mel_input[1-idx, -1].sum() == 0
+            assert linear_input[1-idx, -1].sum() == 0
+            assert stop_target[1-idx, -1] == 1
+            assert len(mel_lengths.shape) == 1
+
+            # check batch conditions
+            assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
+            assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
+
diff --git a/train.py b/train.py
index 7b32d74c..87908717 100644
--- a/train.py
+++ b/train.py
@@ -26,6 +26,7 @@ from utils.model import get_param_size
 from utils.visual import plot_alignment, plot_spectrogram
 from datasets.LJSpeech import LJSpeechDataset
 from models.tacotron import Tacotron
+from layers.losses import L1LossMasked
 
 use_cuda = torch.cuda.is_available()
 
@@ -80,7 +81,8 @@ def train(model, criterion, data_loader, optimizer, epoch):
         text_lengths = data[1]
         linear_input = data[2]
         mel_input = data[3]
-
+        mel_lengths = data[4]
+
         current_step = num_iter + args.restore_step + epoch * len(data_loader) + 1
 
         # setup lr
@@ -93,21 +95,14 @@ def train(model, criterion, data_loader, optimizer, epoch):
         # convert inputs to variables
         text_input_var = Variable(text_input)
         mel_spec_var = Variable(mel_input)
+        mel_lengths_var = Variable(mel_lengths)
         linear_spec_var = Variable(linear_input, volatile=True)
 
-        # sort sequence by length for curriculum learning
-        # TODO: might be unnecessary
-        sorted_lengths, indices = torch.sort(
-            text_lengths.view(-1), dim=0, descending=True)
-        sorted_lengths = sorted_lengths.long().numpy()
-        text_input_var = text_input_var[indices]
-        mel_spec_var = mel_spec_var[indices]
-        linear_spec_var = linear_spec_var[indices]
-
         # dispatch data to GPU
         if use_cuda:
             text_input_var = text_input_var.cuda()
             mel_spec_var = mel_spec_var.cuda()
+            mel_lengths_var = mel_lengths_var.cuda()
             linear_spec_var = linear_spec_var.cuda()
 
         # forward pass
@@ -115,10 +110,11 @@ def train(model, criterion, data_loader, optimizer, epoch):
             model.forward(text_input_var, mel_spec_var)
 
         # loss computation
-        mel_loss = criterion(mel_output, mel_spec_var)
-        linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \
+        mel_loss = criterion(mel_output, mel_spec_var, mel_lengths_var)
+        linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
             + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
-                              linear_spec_var[: ,: ,:n_priority_freq])
+                              linear_spec_var[: ,: ,:n_priority_freq],
+                              mel_lengths_var)
         loss = mel_loss + linear_loss
 
         # backpass and check the grad norm
@@ -215,28 +211,31 @@ def evaluate(model, criterion, data_loader, current_step):
            text_lengths = data[1]
            linear_input = data[2]
            mel_input = data[3]
+           mel_lengths = data[4]
 
            # convert inputs to variables
            text_input_var = Variable(text_input)
            mel_spec_var = Variable(mel_input)
+           mel_lengths_var = Variable(mel_lengths)
            linear_spec_var = Variable(linear_input, volatile=True)
 
            # dispatch data to GPU
            if use_cuda:
                text_input_var = text_input_var.cuda()
                mel_spec_var = mel_spec_var.cuda()
+               mel_lengths_var = mel_lengths_var.cuda()
                linear_spec_var = linear_spec_var.cuda()
 
            # forward pass
-           mel_output, linear_output, alignments =\
-               model.forward(text_input_var, mel_spec_var)
+           mel_output, linear_output, alignments = model.forward(text_input_var, mel_spec_var)
 
            # loss computation
-           mel_loss = criterion(mel_output, mel_spec_var)
-           linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \
+           mel_loss = criterion(mel_output, mel_spec_var, mel_lengths_var)
+           linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
               + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
-                                linear_spec_var[: ,: ,:n_priority_freq])
-           loss = mel_loss + linear_loss
+                                linear_spec_var[: ,: ,:n_priority_freq],
+                                mel_lengths_var)
+           loss = mel_loss + linear_loss
 
            step_time = time.time() - start_time
            epoch_time += step_time
@@ -333,17 +332,16 @@ def main(args):
                              pin_memory=True)
 
     model = Tacotron(c.embedding_size,
-                     c.hidden_size,
-                     c.num_mels,
                      c.num_freq,
+                     c.num_mels,
                      c.r)
-
+
     optimizer = optim.Adam(model.parameters(), lr=c.lr)
 
     if use_cuda:
-        criterion = nn.L1Loss().cuda()
+        criterion = L1LossMasked().cuda()
     else:
-        criterion = nn.L1Loss()
+        criterion = L1LossMasked()
 
     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
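With the masked criterion, both spectrogram losses in train()/evaluate() now take the mel-frame lengths as a third argument, while the linear loss keeps its 50/50 split between all frequency bins and the low "priority" band. A sketch of that composition; the shapes and the n_priority_freq cut-off below are illustrative stand-ins, not the values computed in train.py:

import torch
from torch.autograd import Variable
from layers.losses import L1LossMasked

criterion = L1LossMasked()
B, T, num_freq, n_priority_freq = 2, 20, 1025, 380   # n_priority_freq: made-up stand-in
linear_out = Variable(torch.rand(B, T, num_freq))
linear_tgt = Variable(torch.rand(B, T, num_freq))
lengths = Variable(torch.LongTensor([20, 14]))

linear_loss = 0.5 * criterion(linear_out, linear_tgt, lengths) \
    + 0.5 * criterion(linear_out[:, :, :n_priority_freq],
                      linear_tgt[:, :, :n_priority_freq], lengths)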
diff --git a/utils/data.py b/utils/data.py
index a38092e9..6c47d5eb 100644
--- a/utils/data.py
+++ b/utils/data.py
@@ -1,7 +1,7 @@
 import numpy as np
 
 
-def pad_data(x, length):
+def _pad_data(x, length):
     _pad = 0
     assert x.ndim == 1
     return np.pad(x, (0, length - x.shape[0]),
@@ -11,7 +11,33 @@ def pad_data(x, length):
 
 def prepare_data(inputs):
     max_len = max((len(x) for x in inputs))
-    return np.stack([pad_data(x, max_len) for x in inputs])
+    return np.stack([_pad_data(x, max_len) for x in inputs])
+
+
+def _pad_tensor(x, length):
+    _pad = 0
+    assert x.ndim == 2
+    x = np.pad(x, [[0, 0], [0, length - x.shape[1]]], mode='constant', constant_values=_pad)
+    return x
+
+def prepare_tensor(inputs, out_steps):
+    max_len = max((x.shape[1] for x in inputs)) + 1  # zero-frame
+    remainder = max_len % out_steps
+    pad_len = max_len + (out_steps - remainder) if remainder > 0 else max_len
+    return np.stack([_pad_tensor(x, pad_len) for x in inputs])
+
+
+def _pad_stop_target(x, length):
+    _pad = 1.
+    assert x.ndim == 1
+    return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
+
+
+def prepare_stop_target(inputs, out_steps):
+    max_len = max((x.shape[0] for x in inputs)) + 1  # zero-frame
+    remainder = max_len % out_steps
+    pad_len = max_len + (out_steps - remainder) if remainder > 0 else max_len
+    return np.stack([_pad_stop_target(x, pad_len) for x in inputs])
 
 
 def pad_per_step(inputs, pad_len):
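A worked example of the two padding helpers added above: features are padded along the time axis with zeros and stop targets with ones, both out to the longest length plus one zero-frame, rounded up to a multiple of out_steps. The frame counts are made up, and the import assumes the script is run from the repository root so that utils.data resolves.

import numpy as np
from utils.data import prepare_tensor, prepare_stop_target

r = 5
mels = [np.random.rand(80, 13), np.random.rand(80, 9)]    # (num_mels, T) per item
stops = [np.zeros(13), np.zeros(9)]

print(prepare_tensor(mels, r).shape)        # (2, 80, 15): max T 13 + zero-frame = 14, rounded up to 15
print(prepare_stop_target(stops, r).shape)  # (2, 15); all padded positions are 1.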