Testing of layers and documentation

Eren Golge 2018-02-08 10:10:11 -08:00
parent 584c8fbf5e
commit 7d5bcd6ca4
4 changed files with 179 additions and 72 deletions

View File

@@ -30,15 +30,15 @@ class BahdanauAttention(nn.Module):
        return alignment.squeeze(-1)

-def get_mask_from_lengths(memory, memory_lengths):
+def get_mask_from_lengths(inputs, inputs_lengths):
    """Get mask tensor from list of length
    Args:
-        memory: (batch, max_time, dim)
-        memory_lengths: array like
+        inputs: (batch, max_time, dim)
+        inputs_lengths: array like
    """
-    mask = memory.data.new(memory.size(0), memory.size(1)).byte().zero_()
-    for idx, l in enumerate(memory_lengths):
+    mask = inputs.data.new(inputs.size(0), inputs.size(1)).byte().zero_()
+    for idx, l in enumerate(inputs_lengths):
        mask[idx][:l] = 1
    return ~mask
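
As a side note for review: a minimal, present-day PyTorch paraphrase of what the renamed helper computes (using a bool mask instead of the .byte() tensor above); the batch size and lengths are made up for illustration.

import torch

def get_mask_from_lengths(inputs, inputs_lengths):
    """Return a (batch, max_time) mask that is True at padded positions."""
    mask = torch.zeros(inputs.size(0), inputs.size(1), dtype=torch.bool)
    for idx, l in enumerate(inputs_lengths):
        mask[idx][:l] = True          # mark the valid (non-padded) steps
    return ~mask                      # invert: True now flags padding

inputs = torch.rand(3, 5, 8)          # hypothetical (batch, max_time, dim) batch
print(get_mask_from_lengths(inputs, [5, 3, 2]))
# tensor([[False, False, False, False, False],
#         [False, False, False,  True,  True],
#         [False, False,  True,  True,  True]])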
@@ -51,14 +51,14 @@ class AttentionWrapper(nn.Module):
        self.alignment_model = alignment_model
        self.score_mask_value = score_mask_value

-    def forward(self, query, context_vec, cell_state, memory,
-                processed_inputs=None, mask=None, memory_lengths=None):
+    def forward(self, query, context_vec, cell_state, inputs,
+                processed_inputs=None, mask=None, inputs_lengths=None):
        if processed_inputs is None:
-            processed_inputs = memory
-        if memory_lengths is not None and mask is None:
-            mask = get_mask_from_lengths(memory, memory_lengths)
+            processed_inputs = inputs
+        if inputs_lengths is not None and mask is None:
+            mask = get_mask_from_lengths(inputs, inputs_lengths)

        # Alignment
        # (batch, max_time)
@@ -77,7 +77,7 @@ class AttentionWrapper(nn.Module):
        # Attention context vector
        # (batch, 1, dim)
        # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
-        context_vec = torch.bmm(alignment.unsqueeze(1), memory)
+        context_vec = torch.bmm(alignment.unsqueeze(1), inputs)
        context_vec = context_vec.squeeze(1)

        # Concat input query and previous context_vec context
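
For reference, a small standalone shape check of the context-vector step touched above, c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j; the batch, time, and feature sizes here are arbitrary.

import torch

B, T, D = 4, 7, 256                                      # arbitrary batch, time, encoder dim
alignment = torch.softmax(torch.rand(B, T), dim=-1)      # (B, T) attention weights
inputs = torch.rand(B, T, D)                             # encoder outputs h_j

context_vec = torch.bmm(alignment.unsqueeze(1), inputs)  # (B, 1, T) x (B, T, D) -> (B, 1, D)
context_vec = context_vec.squeeze(1)                     # (B, D)
assert context_vec.shape == (B, D)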

View File

@@ -7,31 +7,55 @@ from .attention import BahdanauAttention, AttentionWrapper
from .attention import get_mask_from_lengths


class Prenet(nn.Module):
-    def __init__(self, in_dim, sizes=[256, 128]):
+    r"""Prenet as explained at https://arxiv.org/abs/1703.10135.
+    It creates as many layers as there are entries in 'out_features'.
+
+    Args:
+        in_features (int): size of the input vector
+        out_features (int or list): size of each output sample.
+            If it is a list, a new layer is created for each value.
+    """
+    def __init__(self, in_features, out_features=[256, 128]):
        super(Prenet, self).__init__()
-        in_sizes = [in_dim] + sizes[:-1]
+        in_features = [in_features] + out_features[:-1]
        self.layers = nn.ModuleList(
            [nn.Linear(in_size, out_size)
-             for (in_size, out_size) in zip(in_sizes, sizes)])
+             for (in_size, out_size) in zip(in_features, out_features)])
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, inputs):
        for linear in self.layers:
            inputs = self.dropout(self.relu(linear(inputs)))
        return inputs


class BatchNormConv1d(nn.Module):
-    def __init__(self, in_dim, out_dim, kernel_size, stride, padding,
+    r"""A wrapper for Conv1d with BatchNorm. It sets the activation
+    function between the Conv and BatchNorm layers. The BatchNorm layer
+    is initialized with the TF default values for momentum and eps.
+
+    Args:
+        in_channels: size of each input sample
+        out_channels: size of each output sample
+        kernel_size: kernel size of conv filters
+        stride: stride of conv filters
+        padding: padding of conv filters
+        activation: activation function set b/w Conv1d and BatchNorm
+
+    Shapes:
+        - input: batch x dims
+        - output: batch x dims
+    """
+    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                 activation=None):
        super(BatchNormConv1d, self).__init__()
-        self.conv1d = nn.Conv1d(in_dim, out_dim,
+        self.conv1d = nn.Conv1d(in_channels, out_channels,
                                kernel_size=kernel_size,
                                stride=stride, padding=padding, bias=False)
        # Following tensorflow's default parameters
-        self.bn = nn.BatchNorm1d(out_dim, momentum=0.99, eps=1e-3)
+        self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3)
        self.activation = activation

    def forward(self, x):
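
A quick shape sanity check for the renamed arguments above; it assumes the module is importable as TTS.layers.tacotron (the path used by the new tests) and uses made-up feature sizes.

import torch
from TTS.layers.tacotron import Prenet, BatchNormConv1d

prenet = Prenet(in_features=80, out_features=[256, 128])
x = torch.rand(4, 80)
assert prenet(x).shape == (4, 128)          # last entry of out_features

conv = BatchNormConv1d(in_channels=80, out_channels=128, kernel_size=3,
                       stride=1, padding=1, activation=torch.nn.ReLU())
y = torch.rand(4, 80, 50)                   # (batch, channels, time)
assert conv(y).shape == (4, 128, 50)        # kernel 3 / padding 1 keeps the time axis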
@@ -62,86 +86,109 @@ class CBHG(nn.Module):
        - 1-d convolution banks
        - Highway networks + residual connections
        - Bidirectional gated recurrent units
+
+    Args:
+        in_features (int): sample size
+        K (int): max filter size in conv bank
+        projections (list): conv channel sizes for conv projections
+        num_highways (int): number of highway layers
+
+    Shapes:
+        - input: batch x time x dim
+        - output: batch x time x dim*2
    """

-    def __init__(self, in_dim, K=16, projections=[128, 128]):
+    def __init__(self, in_features, K=16, projections=[128, 128], num_highways=4):
        super(CBHG, self).__init__()
-        self.in_dim = in_dim
+        self.in_features = in_features
        self.relu = nn.ReLU()
+        # list of conv1d banks with filter sizes k=1...K
+        # TODO: try dilated layers instead
        self.conv1d_banks = nn.ModuleList(
-            [BatchNormConv1d(in_dim, in_dim, kernel_size=k, stride=1,
+            [BatchNormConv1d(in_features, in_features, kernel_size=k, stride=1,
                             padding=k // 2, activation=self.relu)
             for k in range(1, K + 1)])
+        # max pooling of conv bank
+        # TODO: try average pooling OR larger kernel size
        self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)

-        in_sizes = [K * in_dim] + projections[:-1]
-        activations = [self.relu] * (len(projections) - 1) + [None]
-        self.conv1d_projections = nn.ModuleList(
-            [BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
-                             padding=1, activation=ac)
-             for (in_size, out_size, ac) in zip(
-                 in_sizes, projections, activations)])
+        out_features = [K * in_features] + projections[:-1]
+        activations = [self.relu] * (len(projections) - 1)
+        activations += [None]
+
+        # setup conv1d projection layers
+        layer_set = []
+        for (in_size, out_size, ac) in zip(out_features, projections, activations):
+            layer = BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
+                                    padding=1, activation=ac)
+            layer_set.append(layer)
+        self.conv1d_projections = nn.ModuleList(layer_set)

-        self.pre_highway = nn.Linear(projections[-1], in_dim, bias=False)
+        # setup Highway layers
+        self.pre_highway = nn.Linear(projections[-1], in_features, bias=False)
        self.highways = nn.ModuleList(
-            [Highway(in_dim, in_dim) for _ in range(4)])
+            [Highway(in_features, in_features) for _ in range(num_highways)])

+        # bi-directional GRU layer
        self.gru = nn.GRU(
-            in_dim, in_dim, 1, batch_first=True, bidirectional=True)
+            in_features, in_features, 1, batch_first=True, bidirectional=True)

    def forward(self, inputs):
-        # (B, T_in, in_dim)
+        # (B, T_in, in_features)
        x = inputs

        # Needed to perform conv1d on time-axis
-        # (B, in_dim, T_in)
-        if x.size(-1) == self.in_dim:
+        # (B, in_features, T_in)
+        if x.size(-1) == self.in_features:
            x = x.transpose(1, 2)
        T = x.size(-1)

-        # (B, in_dim*K, T_in)
+        # (B, in_features*K, T_in)
        # Concat conv1d bank outputs
-        x = torch.cat([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks], dim=1)
-        assert x.size(1) == self.in_dim * len(self.conv1d_banks)
+        outs = []
+        for conv1d in self.conv1d_banks:
+            out = conv1d(x)
+            out = out[:, :, :T]
+            outs.append(out)
+        x = torch.cat(outs, dim=1)
+        assert x.size(1) == self.in_features * len(self.conv1d_banks)
        x = self.max_pool1d(x)[:, :, :T]

        for conv1d in self.conv1d_projections:
            x = conv1d(x)

-        # (B, T_in, in_dim)
+        # (B, T_in, in_features)
        # Back to the original shape
        x = x.transpose(1, 2)

-        if x.size(-1) != self.in_dim:
+        if x.size(-1) != self.in_features:
            x = self.pre_highway(x)

        # Residual connection
+        # TODO: try residual scaling as in Deep Voice 3
+        # TODO: try plain residual layers
        x += inputs
        for highway in self.highways:
            x = highway(x)

-        # if input_lengths is not None:
-        #     print(x.size())
-        #     print(len(input_lengths))
-        #     x = nn.utils.rnn.pack_padded_sequence(
-        #         x, input_lengths.data.cpu().numpy(), batch_first=True)
-        # (B, T_in, in_dim*2)
+        # (B, T_in, in_features*2)
+        # TODO: replace GRU with convolution as in Deep Voice 3
        self.gru.flatten_parameters()
        outputs, _ = self.gru(x)

-        #if input_lengths is not None:
-        #    outputs, _ = nn.utils.rnn.pad_packed_sequence(
-        #        outputs, batch_first=True)
        return outputs


class Encoder(nn.Module):
-    def __init__(self, in_dim):
+    r"""Encapsulate Prenet and CBHG modules for the encoder."""
+    def __init__(self, in_features):
        super(Encoder, self).__init__()
-        self.prenet = Prenet(in_dim, sizes=[256, 128])
+        self.prenet = Prenet(in_features, out_features=[256, 128])
        self.cbhg = CBHG(128, K=16, projections=[128, 128])

    def forward(self, inputs):
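
Under the same import assumption, a rough check of the CBHG shapes documented above: the K bank outputs are concatenated channel-wise (K * in_features channels) before pooling and the projections, and the bidirectional GRU doubles the feature size on the way out. The sizes below are arbitrary.

import torch
from TTS.layers.tacotron import CBHG

cbhg = CBHG(in_features=128, K=8, projections=[128, 128], num_highways=4)
x = torch.rand(4, 10, 128)                  # (batch, time, in_features)
out = cbhg(x)
assert out.shape == (4, 10, 256)            # batch x time x in_features*2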
@@ -150,22 +197,32 @@ class Encoder(nn.Module):

class Decoder(nn.Module):
-    def __init__(self, memory_dim, r):
+    r"""Decoder module.
+
+    Args:
+        in_features (int): input (encoder output) sample size
+        memory_dim (int): memory vector sample size
+        r (int): number of outputs per time step
+
+    Shape:
+        - input:
+        - output:
+    """
+    def __init__(self, in_features, memory_dim, r):
        super(Decoder, self).__init__()
        self.max_decoder_steps = 200
        self.memory_dim = memory_dim
        self.r = r
        # input -> |Linear| -> processed_inputs
-        self.input_layer = nn.Linear(256, 256, bias=False)
+        self.input_layer = nn.Linear(in_features, 256, bias=False)
        # memory -> |Prenet| -> processed_memory
-        self.prenet = Prenet(memory_dim * r, sizes=[256, 128])
-        # processed_inputs, prrocessed_memory -> |Attention| -> Attention, Alignment, RNN_State
+        self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
+        # processed_inputs, processed_memory -> |Attention| -> Attention, Alignment, RNN_State
        self.attention_rnn = AttentionWrapper(
-            nn.GRUCell(256 + 128, 256),
+            nn.GRUCell(in_features + 128, 256),
            BahdanauAttention(256)
        )
-        # (prenet_out | attention context) -> |Linear| -> decoder_RNN_input
-        self.project_to_decoder_in = nn.Linear(512, 256)
+        # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
+        self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
        # decoder_RNN_input -> |RNN| -> RNN_state
        self.decoder_rnns = nn.ModuleList(
            [nn.GRUCell(256, 256) for _ in range(2)])
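
The sizes that were previously hard-coded (256 and 512) now follow from in_features; a small check of that wiring, using 256 as an example encoder output size (128 CBHG units times 2 GRU directions).

import torch.nn as nn

in_features = 256                                # example encoder output size
attention_rnn_cell = nn.GRUCell(in_features + 128, 256)
project_to_decoder_in = nn.Linear(256 + in_features, 256)

assert attention_rnn_cell.input_size == 384      # prenet output (128) + context vector
assert project_to_decoder_in.in_features == 512  # attention RNN state (256) + context vector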
@@ -173,22 +230,26 @@ class Decoder(nn.Module):
        self.proj_to_mel = nn.Linear(256, memory_dim * r)

    def forward(self, inputs, memory=None, memory_lengths=None):
-        """
+        r"""
        Decoder forward step.

        If decoder inputs are not given (e.g., at testing time), as noted in
        Tacotron paper, greedy decoding is adapted.

        Args:
-            inputs: Encoder outputs. (B, T_encoder, dim)
-            memory: Decoder memory. i.e., mel-spectrogram. If None (at eval-time),
+            inputs: Encoder outputs.
+            memory: Decoder memory (autoregression). If None (at eval-time),
                decoder outputs are used as decoder inputs.
            memory_lengths: Encoder output (memory) lengths. If not None, used for
                attention masking.
+
+        Shapes:
+            - inputs: batch x time x encoder_out_dim
+            - memory: batch x #mel_specs x mel_spec_dim
        """
        B = inputs.size(0)

-        # TODO: take thi segment into Attention module.
+        # TODO: move this segment into the Attention module.
        processed_inputs = self.input_layer(inputs)
        if memory_lengths is not None:
            mask = get_mask_from_lengths(processed_inputs, memory_lengths)
@@ -199,9 +260,12 @@ class Decoder(nn.Module):
        greedy = memory is None

        if memory is not None:
            # Grouping multiple frames if necessary
            if memory.size(-1) == self.memory_dim:
+                print(" > Blamento", memory.shape)
                memory = memory.view(B, memory.size(1) // self.r, -1)
+                print(" > Blamento", memory.shape)
            assert memory.size(-1) == self.memory_dim * self.r,\
                " !! Dimension mismatch {} vs {} * {}".format(memory.size(-1),
                    self.memory_dim, self.r)
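
A worked example of the frame-grouping step above, using the same sizes as the new decoder test (memory_dim = 32, r = 5): 120 mel frames are folded into 24 decoder steps of 5 stacked frames each.

import torch

B, T, memory_dim, r = 4, 120, 32, 5
memory = torch.rand(B, T, memory_dim)        # batch x #mel_specs x mel_spec_dim

if memory.size(-1) == memory_dim:
    memory = memory.view(B, memory.size(1) // r, -1)

assert memory.shape == (4, 24, 160)          # 120 / 5 steps, 32 * 5 values per step
assert memory.size(-1) == memory_dim * r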
@@ -233,11 +297,11 @@ class Decoder(nn.Module):
            if t > 0:
                memory_input = outputs[-1] if greedy else memory[t - 1]
            # Prenet
-            memory_input = self.prenet(memory_input)
+            processed_memory = self.prenet(memory_input)

            # Attention RNN
            attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn(
-                memory_input, current_context_vec, attention_rnn_hidden,
+                processed_memory, current_context_vec, attention_rnn_hidden,
                inputs, processed_inputs=processed_inputs, mask=mask)

            # Concat RNN output and attention context vector

View File

@@ -1 +1,45 @@
import unittest
+
+import torch as T
+
+from TTS.layers.tacotron import Prenet, CBHG, Decoder
+
+
+class PrenetTests(unittest.TestCase):
+    def test_in_out(self):
+        layer = Prenet(128, out_features=[256, 128])
+        dummy_input = T.autograd.Variable(T.rand(4, 128))
+
+        print(layer)
+        output = layer(dummy_input)
+        assert output.shape[0] == 4
+        assert output.shape[1] == 128
+
+
+class CBHGTests(unittest.TestCase):
+    def test_in_out(self):
+        layer = CBHG(128, K=6, projections=[128, 128], num_highways=2)
+        dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
+
+        print(layer)
+        output = layer(dummy_input)
+        assert output.shape[0] == 4
+        assert output.shape[1] == 8
+        assert output.shape[2] == 256
+
+
+class DecoderTests(unittest.TestCase):
+    def test_in_out(self):
+        layer = Decoder(in_features=128, memory_dim=32, r=5)
+        dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
+        dummy_memory = T.autograd.Variable(T.rand(4, 120, 32))
+
+        print(layer)
+        output, alignment = layer(dummy_input, dummy_memory)
+        print(output.shape)
+        assert output.shape[0] == 4
+        assert output.shape[1] == 120 / 5
+        assert output.shape[2] == 32 * 5
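
The new test classes have no __main__ guard, so they are presumably meant to be collected by a runner; a hedged example of invoking them from Python, assuming the tests live in a tests/ package and use a *tests.py naming pattern (adjust the pattern to the repository's actual layout).

import unittest

suite = unittest.defaultTestLoader.discover("tests", pattern="*tests.py")
unittest.TextTestRunner(verbosity=2).run(suite)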

View File

@@ -41,9 +41,8 @@ class TestDataset(unittest.TestCase):
                break
            text_input = data[0]
            text_lengths = data[1]
-            print(text_lengths)
            magnitude_input = data[2]
            mel_input = data[3]
            item_idx = data[4]

            neg_values = text_input[text_input < 0]
            check_count = len(neg_values)