Testing of layers and documentation

This commit is contained in:
Eren Golge 2018-02-08 10:10:11 -08:00
parent 584c8fbf5e
commit 7d5bcd6ca4
4 changed files with 179 additions and 72 deletions

View File

@ -30,15 +30,15 @@ class BahdanauAttention(nn.Module):
return alignment.squeeze(-1)
def get_mask_from_lengths(memory, memory_lengths):
def get_mask_from_lengths(inputs, inputs_lengths):
"""Get mask tensor from list of length
Args:
memory: (batch, max_time, dim)
memory_lengths: array like
inputs: (batch, max_time, dim)
inputs_lengths: array like
"""
mask = memory.data.new(memory.size(0), memory.size(1)).byte().zero_()
for idx, l in enumerate(memory_lengths):
mask = inputs.data.new(inputs.size(0), inputs.size(1)).byte().zero_()
for idx, l in enumerate(inputs_lengths):
mask[idx][:l] = 1
return ~mask
@ -51,14 +51,14 @@ class AttentionWrapper(nn.Module):
self.alignment_model = alignment_model
self.score_mask_value = score_mask_value
def forward(self, query, context_vec, cell_state, memory,
processed_inputs=None, mask=None, memory_lengths=None):
def forward(self, query, context_vec, cell_state, inputs,
processed_inputs=None, mask=None, inputs_lengths=None):
if processed_inputs is None:
processed_inputs = memory
processed_inputs = inputs
if memory_lengths is not None and mask is None:
mask = get_mask_from_lengths(memory, memory_lengths)
if inputs_lengths is not None and mask is None:
mask = get_mask_from_lengths(inputs, inputs_lengths)
# Alignment
# (batch, max_time)
@ -77,7 +77,7 @@ class AttentionWrapper(nn.Module):
# Attention context vector
# (batch, 1, dim)
# c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
context_vec = torch.bmm(alignment.unsqueeze(1), memory)
context_vec = torch.bmm(alignment.unsqueeze(1), inputs)
context_vec = context_vec.squeeze(1)
# Concat input query and previous context_vec context

View File

@ -7,35 +7,59 @@ from .attention import BahdanauAttention, AttentionWrapper
from .attention import get_mask_from_lengths
class Prenet(nn.Module):
def __init__(self, in_dim, sizes=[256, 128]):
r""" Prenet as explained at https://arxiv.org/abs/1703.10135.
It creates as many layers as given by 'out_features'
Args:
in_features (int): size of the input vector
out_features (int or list): size of each output sample.
If it is a list, for each value, there is created a new layer.
"""
def __init__(self, in_features, out_features=[256, 128]):
super(Prenet, self).__init__()
in_sizes = [in_dim] + sizes[:-1]
in_features = [in_features] + out_features[:-1]
self.layers = nn.ModuleList(
[nn.Linear(in_size, out_size)
for (in_size, out_size) in zip(in_sizes, sizes)])
for (in_size, out_size) in zip(in_features, out_features)])
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.5)
def forward(self, inputs):
for linear in self.layers:
inputs = self.dropout(self.relu(linear(inputs)))
return inputs
class BatchNormConv1d(nn.Module):
def __init__(self, in_dim, out_dim, kernel_size, stride, padding,
r"""A wrapper for Conv1d with BatchNorm. It sets the activation
function between Conv and BatchNorm layers. BatchNorm layer
is initialized with the TF default values for momentum and eps.
Args:
in_channels: size of each input sample
out_channels: size of each output samples
kernel_size: kernel size of conv filters
stride: stride of conv filters
padding: padding of conv filters
activation: activation function set b/w Conv1d and BatchNorm
Shapes:
- input: batch x dims
- output: batch x dims
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
activation=None):
super(BatchNormConv1d, self).__init__()
self.conv1d = nn.Conv1d(in_dim, out_dim,
self.conv1d = nn.Conv1d(in_channels, out_channels,
kernel_size=kernel_size,
stride=stride, padding=padding, bias=False)
# Following tensorflow's default parameters
self.bn = nn.BatchNorm1d(out_dim, momentum=0.99, eps=1e-3)
self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3)
self.activation = activation
def forward(self, x):
x = self.conv1d(x)
x = self.conv1d(x)
if self.activation is not None:
x = self.activation(x)
return self.bn(x)
@ -62,86 +86,109 @@ class CBHG(nn.Module):
- 1-d convolution banks
- Highway networks + residual connections
- Bidirectional gated recurrent units
Args:
in_features (int): sample size
K (int): max filter size in conv bank
projections (list): conv channel sizes for conv projections
num_highways (int): number of highways layers
Shapes:
- input: batch x time x dim
- output: batch x time x dim*2
"""
def __init__(self, in_dim, K=16, projections=[128, 128]):
def __init__(self, in_features, K=16, projections=[128, 128], num_highways=4):
super(CBHG, self).__init__()
self.in_dim = in_dim
self.in_features = in_features
self.relu = nn.ReLU()
# list of conv1d bank with filter size k=1...K
# TODO: try dilational layers instead
self.conv1d_banks = nn.ModuleList(
[BatchNormConv1d(in_dim, in_dim, kernel_size=k, stride=1,
[BatchNormConv1d(in_features, in_features, kernel_size=k, stride=1,
padding=k // 2, activation=self.relu)
for k in range(1, K + 1)])
for k in range(1, K + 1)])
# max pooling of conv bank
# TODO: try average pooling OR larger kernel size
self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
in_sizes = [K * in_dim] + projections[:-1]
activations = [self.relu] * (len(projections) - 1) + [None]
self.conv1d_projections = nn.ModuleList(
[BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
padding=1, activation=ac)
for (in_size, out_size, ac) in zip(
in_sizes, projections, activations)])
out_features = [K * in_features] + projections[:-1]
activations = [self.relu] * (len(projections) - 1)
activations += [None]
self.pre_highway = nn.Linear(projections[-1], in_dim, bias=False)
# setup conv1d projection layers
layer_set = []
for (in_size, out_size, ac) in zip(out_features, projections, activations):
layer = BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
padding=1, activation=ac)
layer_set.append(layer)
self.conv1d_projections = nn.ModuleList(layer_set)
# setup Highway layers
self.pre_highway = nn.Linear(projections[-1], in_features, bias=False)
self.highways = nn.ModuleList(
[Highway(in_dim, in_dim) for _ in range(4)])
[Highway(in_features, in_features) for _ in range(num_highways)])
# bi-directional GPU layer
self.gru = nn.GRU(
in_dim, in_dim, 1, batch_first=True, bidirectional=True)
in_features, in_features, 1, batch_first=True, bidirectional=True)
def forward(self, inputs):
# (B, T_in, in_dim)
# (B, T_in, in_features)
x = inputs
# Needed to perform conv1d on time-axis
# (B, in_dim, T_in)
if x.size(-1) == self.in_dim:
# (B, in_features, T_in)
if x.size(-1) == self.in_features:
x = x.transpose(1, 2)
T = x.size(-1)
# (B, in_dim*K, T_in)
# (B, in_features*K, T_in)
# Concat conv1d bank outputs
x = torch.cat([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks], dim=1)
assert x.size(1) == self.in_dim * len(self.conv1d_banks)
outs = []
for conv1d in self.conv1d_banks:
out = conv1d(x)
out = out[:, :, :T]
outs.append(out)
x = torch.cat(outs, dim=1)
assert x.size(1) == self.in_features * len(self.conv1d_banks)
x = self.max_pool1d(x)[:, :, :T]
for conv1d in self.conv1d_projections:
x = conv1d(x)
# (B, T_in, in_dim)
# (B, T_in, in_features)
# Back to the original shape
x = x.transpose(1, 2)
if x.size(-1) != self.in_dim:
if x.size(-1) != self.in_features:
x = self.pre_highway(x)
# Residual connection
# TODO: try residual scaling as in Deep Voice 3
# TODO: try plain residual layers
x += inputs
for highway in self.highways:
x = highway(x)
# if input_lengths is not None:
# print(x.size())
# print(len(input_lengths))
# x = nn.utils.rnn.pack_padded_sequence(
# x, input_lengths.data.cpu().numpy(), batch_first=True)
# (B, T_in, in_dim*2)
self.gru.flatten_parameters()
# (B, T_in, in_features*2)
# TODO: replace GRU with convolution as in Deep Voice 3
self.gru.flatten_parameters()
outputs, _ = self.gru(x)
#if input_lengths is not None:
# outputs, _ = nn.utils.rnn.pad_packed_sequence(
# outputs, batch_first=True)
return outputs
class Encoder(nn.Module):
def __init__(self, in_dim):
r"""Encapsulate Prenet and CBHG modules for encoder"""
def __init__(self, in_features):
super(Encoder, self).__init__()
self.prenet = Prenet(in_dim, sizes=[256, 128])
self.prenet = Prenet(in_features, out_features=[256, 128])
self.cbhg = CBHG(128, K=16, projections=[128, 128])
def forward(self, inputs):
@ -150,22 +197,32 @@ class Encoder(nn.Module):
class Decoder(nn.Module):
def __init__(self, memory_dim, r):
r"""Decoder module.
Args:
memory_dim (int): memory vector sample size
r (int): number of outputs per time step
Shape:
- input:
- output:
"""
def __init__(self, in_features, memory_dim, r):
super(Decoder, self).__init__()
self.max_decoder_steps = 200
self.memory_dim = memory_dim
self.r = r
# input -> |Linear| -> processed_inputs
self.input_layer = nn.Linear(256, 256, bias=False)
self.input_layer = nn.Linear(in_features, 256, bias=False)
# memory -> |Prenet| -> processed_memory
self.prenet = Prenet(memory_dim * r, sizes=[256, 128])
# processed_inputs, prrocessed_memory -> |Attention| -> Attention, Alignment, RNN_State
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
# processed_inputs, processed_memory -> |Attention| -> Attention, Alignment, RNN_State
self.attention_rnn = AttentionWrapper(
nn.GRUCell(256 + 128, 256),
nn.GRUCell(in_features + 128, 256),
BahdanauAttention(256)
)
# (prenet_out | attention context) -> |Linear| -> decoder_RNN_input
self.project_to_decoder_in = nn.Linear(512, 256)
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
self.project_to_decoder_in = nn.Linear(256+in_features, 256)
# decoder_RNN_input -> |RNN| -> RNN_state
self.decoder_rnns = nn.ModuleList(
[nn.GRUCell(256, 256) for _ in range(2)])
@ -173,22 +230,26 @@ class Decoder(nn.Module):
self.proj_to_mel = nn.Linear(256, memory_dim * r)
def forward(self, inputs, memory=None, memory_lengths=None):
"""
r"""
Decoder forward step.
If decoder inputs are not given (e.g., at testing time), as noted in
Tacotron paper, greedy decoding is adapted.
Args:
inputs: Encoder outputs. (B, T_encoder, dim)
memory: Decoder memory. i.e., mel-spectrogram. If None (at eval-time),
inputs: Encoder outputs.
memory: Decoder memory (autoregression. If None (at eval-time),
decoder outputs are used as decoder inputs.
memory_lengths: Encoder output (memory) lengths. If not None, used for
attention masking.
Shapes:
- inputs: batch x time x encoder_out_dim
- memory: batch x #mels_pecs x mel_spec_dim
"""
B = inputs.size(0)
# TODO: take thi segment into Attention module.
# TODO: take this segment into Attention module.
processed_inputs = self.input_layer(inputs)
if memory_lengths is not None:
mask = get_mask_from_lengths(processed_inputs, memory_lengths)
@ -199,9 +260,12 @@ class Decoder(nn.Module):
greedy = memory is None
if memory is not None:
# Grouping multiple frames if necessary
if memory.size(-1) == self.memory_dim:
print(" > Blamento", memory.shape)
memory = memory.view(B, memory.size(1) // self.r, -1)
print(" > Blamento", memory.shape)
assert memory.size(-1) == self.memory_dim * self.r,\
" !! Dimension mismatch {} vs {} * {}".format(memory.size(-1),
self.memory_dim, self.r)
@ -233,11 +297,11 @@ class Decoder(nn.Module):
if t > 0:
memory_input = outputs[-1] if greedy else memory[t - 1]
# Prenet
memory_input = self.prenet(memory_input)
processed_memory = self.prenet(memory_input)
# Attention RNN
attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn(
memory_input, current_context_vec, attention_rnn_hidden,
processed_memory, current_context_vec, attention_rnn_hidden,
inputs, processed_inputs=processed_inputs, mask=mask)
# Concat RNN output and attention context vector

View File

@ -1 +1,45 @@
import unittest
import unittest
import torch as T
from TTS.layers.tacotron import Prenet, CBHG, Decoder
class PrenetTests(unittest.TestCase):
def test_in_out(self):
layer = Prenet(128, out_features=[256, 128])
dummy_input = T.autograd.Variable(T.rand(4, 128))
print(layer)
output = layer(dummy_input)
assert output.shape[0] == 4
assert output.shape[1] == 128
class CBHGTests(unittest.TestCase):
def test_in_out(self):
layer = CBHG(128, K= 6, projections=[128, 128], num_highways=2)
dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
print(layer)
output = layer(dummy_input)
assert output.shape[0] == 4
assert output.shape[1] == 8
assert output.shape[2] == 256
class DecoderTests(unittest.TestCase):
def test_in_out(self):
layer = Decoder(in_features=128, memory_dim=32, r=5)
dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
dummy_memory = T.autograd.Variable(T.rand(4, 120, 32))
print(layer)
output, alignment = layer(dummy_input, dummy_memory)
print(output.shape)
assert output.shape[0] == 4
assert output.shape[1] == 120 / 5
assert output.shape[2] == 32 * 5

View File

@ -41,9 +41,8 @@ class TestDataset(unittest.TestCase):
break
text_input = data[0]
text_lengths = data[1]
print(text_lengths)
magnitude_input = data[2]
mel_input = data[3]
item_idx = data[4]
neg_values = text_input[text_input < 0]
check_count = len(neg_values)