mirror of https://github.com/coqui-ai/TTS.git
Testing of layers and documentation
This commit is contained in:
parent
584c8fbf5e
commit
7d5bcd6ca4
|
@ -30,15 +30,15 @@ class BahdanauAttention(nn.Module):
|
|||
return alignment.squeeze(-1)
|
||||
|
||||
|
||||
def get_mask_from_lengths(memory, memory_lengths):
|
||||
def get_mask_from_lengths(inputs, inputs_lengths):
|
||||
"""Get mask tensor from list of length
|
||||
|
||||
Args:
|
||||
memory: (batch, max_time, dim)
|
||||
memory_lengths: array like
|
||||
inputs: (batch, max_time, dim)
|
||||
inputs_lengths: array like
|
||||
"""
|
||||
mask = memory.data.new(memory.size(0), memory.size(1)).byte().zero_()
|
||||
for idx, l in enumerate(memory_lengths):
|
||||
mask = inputs.data.new(inputs.size(0), inputs.size(1)).byte().zero_()
|
||||
for idx, l in enumerate(inputs_lengths):
|
||||
mask[idx][:l] = 1
|
||||
return ~mask
|
||||
|
||||
|
@ -51,14 +51,14 @@ class AttentionWrapper(nn.Module):
|
|||
self.alignment_model = alignment_model
|
||||
self.score_mask_value = score_mask_value
|
||||
|
||||
def forward(self, query, context_vec, cell_state, memory,
|
||||
processed_inputs=None, mask=None, memory_lengths=None):
|
||||
def forward(self, query, context_vec, cell_state, inputs,
|
||||
processed_inputs=None, mask=None, inputs_lengths=None):
|
||||
|
||||
if processed_inputs is None:
|
||||
processed_inputs = memory
|
||||
processed_inputs = inputs
|
||||
|
||||
if memory_lengths is not None and mask is None:
|
||||
mask = get_mask_from_lengths(memory, memory_lengths)
|
||||
if inputs_lengths is not None and mask is None:
|
||||
mask = get_mask_from_lengths(inputs, inputs_lengths)
|
||||
|
||||
# Alignment
|
||||
# (batch, max_time)
|
||||
|
@ -77,7 +77,7 @@ class AttentionWrapper(nn.Module):
|
|||
# Attention context vector
|
||||
# (batch, 1, dim)
|
||||
# c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
|
||||
context_vec = torch.bmm(alignment.unsqueeze(1), memory)
|
||||
context_vec = torch.bmm(alignment.unsqueeze(1), inputs)
|
||||
context_vec = context_vec.squeeze(1)
|
||||
|
||||
# Concat input query and previous context_vec context
|
||||
|
|
|
@ -7,35 +7,59 @@ from .attention import BahdanauAttention, AttentionWrapper
|
|||
from .attention import get_mask_from_lengths
|
||||
|
||||
class Prenet(nn.Module):
|
||||
def __init__(self, in_dim, sizes=[256, 128]):
|
||||
r""" Prenet as explained at https://arxiv.org/abs/1703.10135.
|
||||
It creates as many layers as given by 'out_features'
|
||||
|
||||
Args:
|
||||
in_features (int): size of the input vector
|
||||
out_features (int or list): size of each output sample.
|
||||
If it is a list, for each value, there is created a new layer.
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, out_features=[256, 128]):
|
||||
super(Prenet, self).__init__()
|
||||
in_sizes = [in_dim] + sizes[:-1]
|
||||
in_features = [in_features] + out_features[:-1]
|
||||
self.layers = nn.ModuleList(
|
||||
[nn.Linear(in_size, out_size)
|
||||
for (in_size, out_size) in zip(in_sizes, sizes)])
|
||||
for (in_size, out_size) in zip(in_features, out_features)])
|
||||
self.relu = nn.ReLU()
|
||||
self.dropout = nn.Dropout(0.5)
|
||||
|
||||
def forward(self, inputs):
|
||||
for linear in self.layers:
|
||||
inputs = self.dropout(self.relu(linear(inputs)))
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
class BatchNormConv1d(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, kernel_size, stride, padding,
|
||||
r"""A wrapper for Conv1d with BatchNorm. It sets the activation
|
||||
function between Conv and BatchNorm layers. BatchNorm layer
|
||||
is initialized with the TF default values for momentum and eps.
|
||||
|
||||
Args:
|
||||
in_channels: size of each input sample
|
||||
out_channels: size of each output samples
|
||||
kernel_size: kernel size of conv filters
|
||||
stride: stride of conv filters
|
||||
padding: padding of conv filters
|
||||
activation: activation function set b/w Conv1d and BatchNorm
|
||||
|
||||
Shapes:
|
||||
- input: batch x dims
|
||||
- output: batch x dims
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
|
||||
activation=None):
|
||||
super(BatchNormConv1d, self).__init__()
|
||||
self.conv1d = nn.Conv1d(in_dim, out_dim,
|
||||
self.conv1d = nn.Conv1d(in_channels, out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride, padding=padding, bias=False)
|
||||
# Following tensorflow's default parameters
|
||||
self.bn = nn.BatchNorm1d(out_dim, momentum=0.99, eps=1e-3)
|
||||
self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3)
|
||||
self.activation = activation
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1d(x)
|
||||
x = self.conv1d(x)
|
||||
if self.activation is not None:
|
||||
x = self.activation(x)
|
||||
return self.bn(x)
|
||||
|
@ -62,86 +86,109 @@ class CBHG(nn.Module):
|
|||
- 1-d convolution banks
|
||||
- Highway networks + residual connections
|
||||
- Bidirectional gated recurrent units
|
||||
|
||||
Args:
|
||||
in_features (int): sample size
|
||||
K (int): max filter size in conv bank
|
||||
projections (list): conv channel sizes for conv projections
|
||||
num_highways (int): number of highways layers
|
||||
|
||||
Shapes:
|
||||
- input: batch x time x dim
|
||||
- output: batch x time x dim*2
|
||||
"""
|
||||
|
||||
def __init__(self, in_dim, K=16, projections=[128, 128]):
|
||||
def __init__(self, in_features, K=16, projections=[128, 128], num_highways=4):
|
||||
super(CBHG, self).__init__()
|
||||
self.in_dim = in_dim
|
||||
self.in_features = in_features
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
# list of conv1d bank with filter size k=1...K
|
||||
# TODO: try dilational layers instead
|
||||
self.conv1d_banks = nn.ModuleList(
|
||||
[BatchNormConv1d(in_dim, in_dim, kernel_size=k, stride=1,
|
||||
[BatchNormConv1d(in_features, in_features, kernel_size=k, stride=1,
|
||||
padding=k // 2, activation=self.relu)
|
||||
for k in range(1, K + 1)])
|
||||
for k in range(1, K + 1)])
|
||||
|
||||
# max pooling of conv bank
|
||||
# TODO: try average pooling OR larger kernel size
|
||||
self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
|
||||
|
||||
in_sizes = [K * in_dim] + projections[:-1]
|
||||
activations = [self.relu] * (len(projections) - 1) + [None]
|
||||
self.conv1d_projections = nn.ModuleList(
|
||||
[BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
|
||||
padding=1, activation=ac)
|
||||
for (in_size, out_size, ac) in zip(
|
||||
in_sizes, projections, activations)])
|
||||
out_features = [K * in_features] + projections[:-1]
|
||||
activations = [self.relu] * (len(projections) - 1)
|
||||
activations += [None]
|
||||
|
||||
self.pre_highway = nn.Linear(projections[-1], in_dim, bias=False)
|
||||
# setup conv1d projection layers
|
||||
layer_set = []
|
||||
for (in_size, out_size, ac) in zip(out_features, projections, activations):
|
||||
layer = BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
|
||||
padding=1, activation=ac)
|
||||
layer_set.append(layer)
|
||||
self.conv1d_projections = nn.ModuleList(layer_set)
|
||||
|
||||
# setup Highway layers
|
||||
self.pre_highway = nn.Linear(projections[-1], in_features, bias=False)
|
||||
self.highways = nn.ModuleList(
|
||||
[Highway(in_dim, in_dim) for _ in range(4)])
|
||||
[Highway(in_features, in_features) for _ in range(num_highways)])
|
||||
|
||||
# bi-directional GPU layer
|
||||
self.gru = nn.GRU(
|
||||
in_dim, in_dim, 1, batch_first=True, bidirectional=True)
|
||||
in_features, in_features, 1, batch_first=True, bidirectional=True)
|
||||
|
||||
def forward(self, inputs):
|
||||
# (B, T_in, in_dim)
|
||||
# (B, T_in, in_features)
|
||||
x = inputs
|
||||
|
||||
# Needed to perform conv1d on time-axis
|
||||
# (B, in_dim, T_in)
|
||||
if x.size(-1) == self.in_dim:
|
||||
# (B, in_features, T_in)
|
||||
if x.size(-1) == self.in_features:
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
T = x.size(-1)
|
||||
|
||||
# (B, in_dim*K, T_in)
|
||||
# (B, in_features*K, T_in)
|
||||
# Concat conv1d bank outputs
|
||||
x = torch.cat([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks], dim=1)
|
||||
assert x.size(1) == self.in_dim * len(self.conv1d_banks)
|
||||
outs = []
|
||||
for conv1d in self.conv1d_banks:
|
||||
out = conv1d(x)
|
||||
out = out[:, :, :T]
|
||||
outs.append(out)
|
||||
|
||||
x = torch.cat(outs, dim=1)
|
||||
assert x.size(1) == self.in_features * len(self.conv1d_banks)
|
||||
|
||||
x = self.max_pool1d(x)[:, :, :T]
|
||||
|
||||
for conv1d in self.conv1d_projections:
|
||||
x = conv1d(x)
|
||||
|
||||
# (B, T_in, in_dim)
|
||||
# (B, T_in, in_features)
|
||||
# Back to the original shape
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
if x.size(-1) != self.in_dim:
|
||||
if x.size(-1) != self.in_features:
|
||||
x = self.pre_highway(x)
|
||||
|
||||
# Residual connection
|
||||
# TODO: try residual scaling as in Deep Voice 3
|
||||
# TODO: try plain residual layers
|
||||
x += inputs
|
||||
for highway in self.highways:
|
||||
x = highway(x)
|
||||
|
||||
# if input_lengths is not None:
|
||||
# print(x.size())
|
||||
# print(len(input_lengths))
|
||||
# x = nn.utils.rnn.pack_padded_sequence(
|
||||
# x, input_lengths.data.cpu().numpy(), batch_first=True)
|
||||
|
||||
# (B, T_in, in_dim*2)
|
||||
self.gru.flatten_parameters()
|
||||
# (B, T_in, in_features*2)
|
||||
# TODO: replace GRU with convolution as in Deep Voice 3
|
||||
self.gru.flatten_parameters()
|
||||
outputs, _ = self.gru(x)
|
||||
|
||||
#if input_lengths is not None:
|
||||
# outputs, _ = nn.utils.rnn.pad_packed_sequence(
|
||||
# outputs, batch_first=True)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, in_dim):
|
||||
r"""Encapsulate Prenet and CBHG modules for encoder"""
|
||||
|
||||
def __init__(self, in_features):
|
||||
super(Encoder, self).__init__()
|
||||
self.prenet = Prenet(in_dim, sizes=[256, 128])
|
||||
self.prenet = Prenet(in_features, out_features=[256, 128])
|
||||
self.cbhg = CBHG(128, K=16, projections=[128, 128])
|
||||
|
||||
def forward(self, inputs):
|
||||
|
@ -150,22 +197,32 @@ class Encoder(nn.Module):
|
|||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, memory_dim, r):
|
||||
r"""Decoder module.
|
||||
|
||||
Args:
|
||||
memory_dim (int): memory vector sample size
|
||||
r (int): number of outputs per time step
|
||||
|
||||
Shape:
|
||||
- input:
|
||||
- output:
|
||||
"""
|
||||
def __init__(self, in_features, memory_dim, r):
|
||||
super(Decoder, self).__init__()
|
||||
self.max_decoder_steps = 200
|
||||
self.memory_dim = memory_dim
|
||||
self.r = r
|
||||
# input -> |Linear| -> processed_inputs
|
||||
self.input_layer = nn.Linear(256, 256, bias=False)
|
||||
self.input_layer = nn.Linear(in_features, 256, bias=False)
|
||||
# memory -> |Prenet| -> processed_memory
|
||||
self.prenet = Prenet(memory_dim * r, sizes=[256, 128])
|
||||
# processed_inputs, prrocessed_memory -> |Attention| -> Attention, Alignment, RNN_State
|
||||
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
|
||||
# processed_inputs, processed_memory -> |Attention| -> Attention, Alignment, RNN_State
|
||||
self.attention_rnn = AttentionWrapper(
|
||||
nn.GRUCell(256 + 128, 256),
|
||||
nn.GRUCell(in_features + 128, 256),
|
||||
BahdanauAttention(256)
|
||||
)
|
||||
# (prenet_out | attention context) -> |Linear| -> decoder_RNN_input
|
||||
self.project_to_decoder_in = nn.Linear(512, 256)
|
||||
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
|
||||
self.project_to_decoder_in = nn.Linear(256+in_features, 256)
|
||||
# decoder_RNN_input -> |RNN| -> RNN_state
|
||||
self.decoder_rnns = nn.ModuleList(
|
||||
[nn.GRUCell(256, 256) for _ in range(2)])
|
||||
|
@ -173,22 +230,26 @@ class Decoder(nn.Module):
|
|||
self.proj_to_mel = nn.Linear(256, memory_dim * r)
|
||||
|
||||
def forward(self, inputs, memory=None, memory_lengths=None):
|
||||
"""
|
||||
r"""
|
||||
Decoder forward step.
|
||||
|
||||
If decoder inputs are not given (e.g., at testing time), as noted in
|
||||
Tacotron paper, greedy decoding is adapted.
|
||||
|
||||
Args:
|
||||
inputs: Encoder outputs. (B, T_encoder, dim)
|
||||
memory: Decoder memory. i.e., mel-spectrogram. If None (at eval-time),
|
||||
inputs: Encoder outputs.
|
||||
memory: Decoder memory (autoregression. If None (at eval-time),
|
||||
decoder outputs are used as decoder inputs.
|
||||
memory_lengths: Encoder output (memory) lengths. If not None, used for
|
||||
attention masking.
|
||||
|
||||
Shapes:
|
||||
- inputs: batch x time x encoder_out_dim
|
||||
- memory: batch x #mels_pecs x mel_spec_dim
|
||||
"""
|
||||
B = inputs.size(0)
|
||||
|
||||
# TODO: take thi segment into Attention module.
|
||||
# TODO: take this segment into Attention module.
|
||||
processed_inputs = self.input_layer(inputs)
|
||||
if memory_lengths is not None:
|
||||
mask = get_mask_from_lengths(processed_inputs, memory_lengths)
|
||||
|
@ -199,9 +260,12 @@ class Decoder(nn.Module):
|
|||
greedy = memory is None
|
||||
|
||||
if memory is not None:
|
||||
|
||||
# Grouping multiple frames if necessary
|
||||
if memory.size(-1) == self.memory_dim:
|
||||
print(" > Blamento", memory.shape)
|
||||
memory = memory.view(B, memory.size(1) // self.r, -1)
|
||||
print(" > Blamento", memory.shape)
|
||||
assert memory.size(-1) == self.memory_dim * self.r,\
|
||||
" !! Dimension mismatch {} vs {} * {}".format(memory.size(-1),
|
||||
self.memory_dim, self.r)
|
||||
|
@ -233,11 +297,11 @@ class Decoder(nn.Module):
|
|||
if t > 0:
|
||||
memory_input = outputs[-1] if greedy else memory[t - 1]
|
||||
# Prenet
|
||||
memory_input = self.prenet(memory_input)
|
||||
processed_memory = self.prenet(memory_input)
|
||||
|
||||
# Attention RNN
|
||||
attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn(
|
||||
memory_input, current_context_vec, attention_rnn_hidden,
|
||||
processed_memory, current_context_vec, attention_rnn_hidden,
|
||||
inputs, processed_inputs=processed_inputs, mask=mask)
|
||||
|
||||
# Concat RNN output and attention context vector
|
||||
|
|
|
@ -1 +1,45 @@
|
|||
import unittest
|
||||
import unittest
|
||||
import torch as T
|
||||
|
||||
from TTS.layers.tacotron import Prenet, CBHG, Decoder
|
||||
|
||||
|
||||
class PrenetTests(unittest.TestCase):
|
||||
|
||||
def test_in_out(self):
|
||||
layer = Prenet(128, out_features=[256, 128])
|
||||
dummy_input = T.autograd.Variable(T.rand(4, 128))
|
||||
|
||||
|
||||
print(layer)
|
||||
output = layer(dummy_input)
|
||||
assert output.shape[0] == 4
|
||||
assert output.shape[1] == 128
|
||||
|
||||
|
||||
class CBHGTests(unittest.TestCase):
|
||||
|
||||
def test_in_out(self):
|
||||
layer = CBHG(128, K= 6, projections=[128, 128], num_highways=2)
|
||||
dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
|
||||
|
||||
print(layer)
|
||||
output = layer(dummy_input)
|
||||
assert output.shape[0] == 4
|
||||
assert output.shape[1] == 8
|
||||
assert output.shape[2] == 256
|
||||
|
||||
|
||||
class DecoderTests(unittest.TestCase):
|
||||
|
||||
def test_in_out(self):
|
||||
layer = Decoder(in_features=128, memory_dim=32, r=5)
|
||||
dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
|
||||
dummy_memory = T.autograd.Variable(T.rand(4, 120, 32))
|
||||
|
||||
print(layer)
|
||||
output, alignment = layer(dummy_input, dummy_memory)
|
||||
print(output.shape)
|
||||
assert output.shape[0] == 4
|
||||
assert output.shape[1] == 120 / 5
|
||||
assert output.shape[2] == 32 * 5
|
||||
|
|
|
@ -41,9 +41,8 @@ class TestDataset(unittest.TestCase):
|
|||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
print(text_lengths)
|
||||
magnitude_input = data[2]
|
||||
mel_input = data[3]
|
||||
item_idx = data[4]
|
||||
|
||||
neg_values = text_input[text_input < 0]
|
||||
check_count = len(neg_values)
|
||||
|
|
Loading…
Reference in New Issue