diff --git a/layers/attention.py b/layers/attention.py index 958c5701..8d993cea 100644 --- a/layers/attention.py +++ b/layers/attention.py @@ -30,15 +30,15 @@ class BahdanauAttention(nn.Module): return alignment.squeeze(-1) -def get_mask_from_lengths(memory, memory_lengths): +def get_mask_from_lengths(inputs, inputs_lengths): """Get mask tensor from list of length Args: - memory: (batch, max_time, dim) - memory_lengths: array like + inputs: (batch, max_time, dim) + inputs_lengths: array like """ - mask = memory.data.new(memory.size(0), memory.size(1)).byte().zero_() - for idx, l in enumerate(memory_lengths): + mask = inputs.data.new(inputs.size(0), inputs.size(1)).byte().zero_() + for idx, l in enumerate(inputs_lengths): mask[idx][:l] = 1 return ~mask @@ -51,14 +51,14 @@ class AttentionWrapper(nn.Module): self.alignment_model = alignment_model self.score_mask_value = score_mask_value - def forward(self, query, context_vec, cell_state, memory, - processed_inputs=None, mask=None, memory_lengths=None): + def forward(self, query, context_vec, cell_state, inputs, + processed_inputs=None, mask=None, inputs_lengths=None): if processed_inputs is None: - processed_inputs = memory + processed_inputs = inputs - if memory_lengths is not None and mask is None: - mask = get_mask_from_lengths(memory, memory_lengths) + if inputs_lengths is not None and mask is None: + mask = get_mask_from_lengths(inputs, inputs_lengths) # Alignment # (batch, max_time) @@ -77,7 +77,7 @@ class AttentionWrapper(nn.Module): # Attention context vector # (batch, 1, dim) # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j - context_vec = torch.bmm(alignment.unsqueeze(1), memory) + context_vec = torch.bmm(alignment.unsqueeze(1), inputs) context_vec = context_vec.squeeze(1) # Concat input query and previous context_vec context diff --git a/layers/tacotron.py b/layers/tacotron.py index 4fbe7f17..b0f4dfe9 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -7,35 +7,59 @@ from .attention import BahdanauAttention, AttentionWrapper from .attention import get_mask_from_lengths class Prenet(nn.Module): - def __init__(self, in_dim, sizes=[256, 128]): + r""" Prenet as explained at https://arxiv.org/abs/1703.10135. + It creates as many layers as given by 'out_features' + + Args: + in_features (int): size of the input vector + out_features (int or list): size of each output sample. + If it is a list, for each value, there is created a new layer. + """ + + def __init__(self, in_features, out_features=[256, 128]): super(Prenet, self).__init__() - in_sizes = [in_dim] + sizes[:-1] + in_features = [in_features] + out_features[:-1] self.layers = nn.ModuleList( [nn.Linear(in_size, out_size) - for (in_size, out_size) in zip(in_sizes, sizes)]) + for (in_size, out_size) in zip(in_features, out_features)]) self.relu = nn.ReLU() self.dropout = nn.Dropout(0.5) def forward(self, inputs): for linear in self.layers: inputs = self.dropout(self.relu(linear(inputs))) - return inputs class BatchNormConv1d(nn.Module): - def __init__(self, in_dim, out_dim, kernel_size, stride, padding, + r"""A wrapper for Conv1d with BatchNorm. It sets the activation + function between Conv and BatchNorm layers. BatchNorm layer + is initialized with the TF default values for momentum and eps. + + Args: + in_channels: size of each input sample + out_channels: size of each output samples + kernel_size: kernel size of conv filters + stride: stride of conv filters + padding: padding of conv filters + activation: activation function set b/w Conv1d and BatchNorm + + Shapes: + - input: batch x dims + - output: batch x dims + """ + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, activation=None): super(BatchNormConv1d, self).__init__() - self.conv1d = nn.Conv1d(in_dim, out_dim, + self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) # Following tensorflow's default parameters - self.bn = nn.BatchNorm1d(out_dim, momentum=0.99, eps=1e-3) + self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3) self.activation = activation def forward(self, x): - x = self.conv1d(x) + x = self.conv1d(x) if self.activation is not None: x = self.activation(x) return self.bn(x) @@ -62,86 +86,109 @@ class CBHG(nn.Module): - 1-d convolution banks - Highway networks + residual connections - Bidirectional gated recurrent units + + Args: + in_features (int): sample size + K (int): max filter size in conv bank + projections (list): conv channel sizes for conv projections + num_highways (int): number of highways layers + + Shapes: + - input: batch x time x dim + - output: batch x time x dim*2 """ - def __init__(self, in_dim, K=16, projections=[128, 128]): + def __init__(self, in_features, K=16, projections=[128, 128], num_highways=4): super(CBHG, self).__init__() - self.in_dim = in_dim + self.in_features = in_features self.relu = nn.ReLU() + + # list of conv1d bank with filter size k=1...K + # TODO: try dilational layers instead self.conv1d_banks = nn.ModuleList( - [BatchNormConv1d(in_dim, in_dim, kernel_size=k, stride=1, + [BatchNormConv1d(in_features, in_features, kernel_size=k, stride=1, padding=k // 2, activation=self.relu) - for k in range(1, K + 1)]) + for k in range(1, K + 1)]) + + # max pooling of conv bank + # TODO: try average pooling OR larger kernel size self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1) - in_sizes = [K * in_dim] + projections[:-1] - activations = [self.relu] * (len(projections) - 1) + [None] - self.conv1d_projections = nn.ModuleList( - [BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1, - padding=1, activation=ac) - for (in_size, out_size, ac) in zip( - in_sizes, projections, activations)]) + out_features = [K * in_features] + projections[:-1] + activations = [self.relu] * (len(projections) - 1) + activations += [None] - self.pre_highway = nn.Linear(projections[-1], in_dim, bias=False) + # setup conv1d projection layers + layer_set = [] + for (in_size, out_size, ac) in zip(out_features, projections, activations): + layer = BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1, + padding=1, activation=ac) + layer_set.append(layer) + self.conv1d_projections = nn.ModuleList(layer_set) + + # setup Highway layers + self.pre_highway = nn.Linear(projections[-1], in_features, bias=False) self.highways = nn.ModuleList( - [Highway(in_dim, in_dim) for _ in range(4)]) + [Highway(in_features, in_features) for _ in range(num_highways)]) + # bi-directional GPU layer self.gru = nn.GRU( - in_dim, in_dim, 1, batch_first=True, bidirectional=True) + in_features, in_features, 1, batch_first=True, bidirectional=True) def forward(self, inputs): - # (B, T_in, in_dim) + # (B, T_in, in_features) x = inputs # Needed to perform conv1d on time-axis - # (B, in_dim, T_in) - if x.size(-1) == self.in_dim: + # (B, in_features, T_in) + if x.size(-1) == self.in_features: x = x.transpose(1, 2) T = x.size(-1) - # (B, in_dim*K, T_in) + # (B, in_features*K, T_in) # Concat conv1d bank outputs - x = torch.cat([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks], dim=1) - assert x.size(1) == self.in_dim * len(self.conv1d_banks) + outs = [] + for conv1d in self.conv1d_banks: + out = conv1d(x) + out = out[:, :, :T] + outs.append(out) + + x = torch.cat(outs, dim=1) + assert x.size(1) == self.in_features * len(self.conv1d_banks) + x = self.max_pool1d(x)[:, :, :T] for conv1d in self.conv1d_projections: x = conv1d(x) - # (B, T_in, in_dim) + # (B, T_in, in_features) # Back to the original shape x = x.transpose(1, 2) - if x.size(-1) != self.in_dim: + if x.size(-1) != self.in_features: x = self.pre_highway(x) # Residual connection + # TODO: try residual scaling as in Deep Voice 3 + # TODO: try plain residual layers x += inputs for highway in self.highways: x = highway(x) - # if input_lengths is not None: - # print(x.size()) - # print(len(input_lengths)) - # x = nn.utils.rnn.pack_padded_sequence( - # x, input_lengths.data.cpu().numpy(), batch_first=True) - - # (B, T_in, in_dim*2) - self.gru.flatten_parameters() + # (B, T_in, in_features*2) + # TODO: replace GRU with convolution as in Deep Voice 3 + self.gru.flatten_parameters() outputs, _ = self.gru(x) - - #if input_lengths is not None: - # outputs, _ = nn.utils.rnn.pad_packed_sequence( - # outputs, batch_first=True) - return outputs class Encoder(nn.Module): - def __init__(self, in_dim): + r"""Encapsulate Prenet and CBHG modules for encoder""" + + def __init__(self, in_features): super(Encoder, self).__init__() - self.prenet = Prenet(in_dim, sizes=[256, 128]) + self.prenet = Prenet(in_features, out_features=[256, 128]) self.cbhg = CBHG(128, K=16, projections=[128, 128]) def forward(self, inputs): @@ -150,22 +197,32 @@ class Encoder(nn.Module): class Decoder(nn.Module): - def __init__(self, memory_dim, r): + r"""Decoder module. + + Args: + memory_dim (int): memory vector sample size + r (int): number of outputs per time step + + Shape: + - input: + - output: + """ + def __init__(self, in_features, memory_dim, r): super(Decoder, self).__init__() self.max_decoder_steps = 200 self.memory_dim = memory_dim self.r = r # input -> |Linear| -> processed_inputs - self.input_layer = nn.Linear(256, 256, bias=False) + self.input_layer = nn.Linear(in_features, 256, bias=False) # memory -> |Prenet| -> processed_memory - self.prenet = Prenet(memory_dim * r, sizes=[256, 128]) - # processed_inputs, prrocessed_memory -> |Attention| -> Attention, Alignment, RNN_State + self.prenet = Prenet(memory_dim * r, out_features=[256, 128]) + # processed_inputs, processed_memory -> |Attention| -> Attention, Alignment, RNN_State self.attention_rnn = AttentionWrapper( - nn.GRUCell(256 + 128, 256), + nn.GRUCell(in_features + 128, 256), BahdanauAttention(256) ) - # (prenet_out | attention context) -> |Linear| -> decoder_RNN_input - self.project_to_decoder_in = nn.Linear(512, 256) + # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input + self.project_to_decoder_in = nn.Linear(256+in_features, 256) # decoder_RNN_input -> |RNN| -> RNN_state self.decoder_rnns = nn.ModuleList( [nn.GRUCell(256, 256) for _ in range(2)]) @@ -173,22 +230,26 @@ class Decoder(nn.Module): self.proj_to_mel = nn.Linear(256, memory_dim * r) def forward(self, inputs, memory=None, memory_lengths=None): - """ + r""" Decoder forward step. If decoder inputs are not given (e.g., at testing time), as noted in Tacotron paper, greedy decoding is adapted. Args: - inputs: Encoder outputs. (B, T_encoder, dim) - memory: Decoder memory. i.e., mel-spectrogram. If None (at eval-time), + inputs: Encoder outputs. + memory: Decoder memory (autoregression. If None (at eval-time), decoder outputs are used as decoder inputs. memory_lengths: Encoder output (memory) lengths. If not None, used for attention masking. + + Shapes: + - inputs: batch x time x encoder_out_dim + - memory: batch x #mels_pecs x mel_spec_dim """ B = inputs.size(0) - # TODO: take thi segment into Attention module. + # TODO: take this segment into Attention module. processed_inputs = self.input_layer(inputs) if memory_lengths is not None: mask = get_mask_from_lengths(processed_inputs, memory_lengths) @@ -199,9 +260,12 @@ class Decoder(nn.Module): greedy = memory is None if memory is not None: + # Grouping multiple frames if necessary if memory.size(-1) == self.memory_dim: + print(" > Blamento", memory.shape) memory = memory.view(B, memory.size(1) // self.r, -1) + print(" > Blamento", memory.shape) assert memory.size(-1) == self.memory_dim * self.r,\ " !! Dimension mismatch {} vs {} * {}".format(memory.size(-1), self.memory_dim, self.r) @@ -233,11 +297,11 @@ class Decoder(nn.Module): if t > 0: memory_input = outputs[-1] if greedy else memory[t - 1] # Prenet - memory_input = self.prenet(memory_input) + processed_memory = self.prenet(memory_input) # Attention RNN attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn( - memory_input, current_context_vec, attention_rnn_hidden, + processed_memory, current_context_vec, attention_rnn_hidden, inputs, processed_inputs=processed_inputs, mask=mask) # Concat RNN output and attention context vector diff --git a/tests/layers_tests.py b/tests/layers_tests.py index 9ca3c612..0b4bddd1 100644 --- a/tests/layers_tests.py +++ b/tests/layers_tests.py @@ -1 +1,45 @@ -import unittest \ No newline at end of file +import unittest +import torch as T + +from TTS.layers.tacotron import Prenet, CBHG, Decoder + + +class PrenetTests(unittest.TestCase): + + def test_in_out(self): + layer = Prenet(128, out_features=[256, 128]) + dummy_input = T.autograd.Variable(T.rand(4, 128)) + + + print(layer) + output = layer(dummy_input) + assert output.shape[0] == 4 + assert output.shape[1] == 128 + + +class CBHGTests(unittest.TestCase): + + def test_in_out(self): + layer = CBHG(128, K= 6, projections=[128, 128], num_highways=2) + dummy_input = T.autograd.Variable(T.rand(4, 8, 128)) + + print(layer) + output = layer(dummy_input) + assert output.shape[0] == 4 + assert output.shape[1] == 8 + assert output.shape[2] == 256 + + +class DecoderTests(unittest.TestCase): + + def test_in_out(self): + layer = Decoder(in_features=128, memory_dim=32, r=5) + dummy_input = T.autograd.Variable(T.rand(4, 8, 128)) + dummy_memory = T.autograd.Variable(T.rand(4, 120, 32)) + + print(layer) + output, alignment = layer(dummy_input, dummy_memory) + print(output.shape) + assert output.shape[0] == 4 + assert output.shape[1] == 120 / 5 + assert output.shape[2] == 32 * 5 diff --git a/tests/loader_tests.py b/tests/loader_tests.py index e251f767..14da30d5 100644 --- a/tests/loader_tests.py +++ b/tests/loader_tests.py @@ -41,9 +41,8 @@ class TestDataset(unittest.TestCase): break text_input = data[0] text_lengths = data[1] - print(text_lengths) - magnitude_input = data[2] mel_input = data[3] + item_idx = data[4] neg_values = text_input[text_input < 0] check_count = len(neg_values)