docstrings for Tacotron models

erogol 2020-07-23 16:26:20 +02:00
parent 4c7c4f77f8
commit cb4313dcb6
3 changed files with 183 additions and 80 deletions


@@ -18,8 +18,8 @@ class BatchNormConv1d(nn.Module):
activation: activation function set b/w Conv1d and BatchNorm
Shapes:
- input: batch x dims
- output: batch x dims
- input: (B, D)
- output: (B, D)
"""
def __init__(self,
@@ -67,12 +67,23 @@ class BatchNormConv1d(nn.Module):
class Highway(nn.Module):
r"""Highway layers as explained in https://arxiv.org/abs/1505.00387
Args:
in_features (int): size of each input sample
out_feature (int): size of each output sample
Shapes:
- input: (B, *, H_in)
- output: (B, *, H_out)
"""
# TODO: Try GLU layer
def __init__(self, in_size, out_size):
def __init__(self, in_features, out_feature):
super(Highway, self).__init__()
self.H = nn.Linear(in_size, out_size)
self.H = nn.Linear(in_features, out_feature)
self.H.bias.data.zero_()
self.T = nn.Linear(in_size, out_size)
self.T = nn.Linear(in_features, out_feature)
self.T.bias.data.fill_(-1)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
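
Side note on the -1 gate bias above: sigmoid(-1) ≈ 0.27, so a fresh Highway layer mostly carries its input through unchanged, which eases optimization of deep stacks. The forward pass is outside this hunk; a minimal standalone sketch of the standard highway rule it implements (layer sizes here are illustrative):

import torch
import torch.nn as nn

class HighwaySketch(nn.Module):
    def __init__(self, features):
        super().__init__()
        self.H = nn.Linear(features, features)  # transform branch
        self.T = nn.Linear(features, features)  # gate branch
        self.T.bias.data.fill_(-1)              # gate starts near "carry"

    def forward(self, x):
        h = torch.relu(self.H(x))     # transformed features
        t = torch.sigmoid(self.T(x))  # gate in (0, 1)
        return h * t + x * (1 - t)    # blend transform and carry paths

x = torch.rand(8, 128)
assert HighwaySketch(128)(x).shape == (8, 128)
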
@@ -103,8 +114,8 @@ class CBHG(nn.Module):
num_highways (int): number of highway layers
Shapes:
- input: B x D x T_in
- output: B x T_in x D*2
- input: (B, C, T_in)
- output: (B, T_in, C*2)
"""
def __init__(self,
@@ -195,6 +206,8 @@ class CBHG(nn.Module):
class EncoderCBHG(nn.Module):
r"""CBHG module with Encoder specific arguments"""
def __init__(self):
super(EncoderCBHG, self).__init__()
self.cbhg = CBHG(
@@ -211,7 +224,14 @@ class EncoderCBHG(nn.Module):
class Encoder(nn.Module):
r"""Encapsulate Prenet and CBHG modules for encoder"""
r"""Stack Prenet and CBHG module for encoder
Args:
inputs (FloatTensor): embedding features
Shapes:
- inputs: (B, T, D_in)
- outputs: (B, T, 128 * 2)
"""
def __init__(self, in_features):
super(Encoder, self).__init__()
@@ -219,14 +239,6 @@ class Encoder(nn.Module):
self.cbhg = EncoderCBHG()
def forward(self, inputs):
r"""
Args:
inputs (FloatTensor): embedding features
Shapes:
- inputs: batch x time x in_features
- outputs: batch x time x 128*2
"""
# B x T x prenet_dim
outputs = self.prenet(inputs)
outputs = self.cbhg(outputs.transpose(1, 2))
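
The transpose above feeds the CBHG channel-first, matching its documented (B, C, T_in) input; the documented (B, T, 128 * 2) encoder output comes from the CBHG's final bidirectional GRU, which concatenates both directions. A minimal shape check of that last step (the 128-wide GRU is an assumption based on the docstring; the real CBHG also runs conv banks and highway layers first):

import torch
import torch.nn as nn

gru = nn.GRU(128, 128, batch_first=True, bidirectional=True)
prenet_out = torch.rand(2, 37, 128)  # (B, T, prenet_dim)
out, _ = gru(prenet_out)
assert out.shape == (2, 37, 256)     # (B, T, 128 * 2)
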
@@ -250,35 +262,48 @@ class PostCBHG(nn.Module):
class Decoder(nn.Module):
"""Decoder module.
"""Tacotron decoder.
Args:
in_features (int): input vector (encoder output) sample size.
memory_dim (int): memory vector (prev. time-step output) sample size.
r (int): number of outputs per time step.
in_channels (int): number of input channels.
frame_channels (int): number of feature frame channels.
r (int): number of outputs per time step (reduction rate).
memory_size (int): size of the past window. If <= 0, memory_size = r.
TODO: arguments
attn_type (string): type of attention used in decoder.
attn_windowing (bool): if true, define an attention window centered on the maximum
attention response. It provides more robust attention alignment, especially
at inference time.
attn_norm (string): attention normalization function. 'sigmoid' or 'softmax'.
prenet_type (string): 'original' or 'bn'.
prenet_dropout (float): prenet dropout rate.
forward_attn (bool): if true, use forward attention method. https://arxiv.org/abs/1807.06736
trans_agent (bool): if true, use transition agent. https://arxiv.org/abs/1807.06736
forward_attn_mask (bool): if true, mask attention values smaller than a threshold.
location_attn (bool): if true, use location sensitive attention.
attn_K (int): number of attention heads for GravesAttention.
separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow.
speaker_embedding_dim (int): size of speaker embedding vector, for multi-speaker training.
"""
# Pylint gets confused by PyTorch conventions here
#pylint: disable=attribute-defined-outside-init
# pylint: disable=attribute-defined-outside-init
def __init__(self, in_features, memory_dim, r, memory_size, attn_type, attn_windowing,
def __init__(self, in_channels, frame_channels, r, memory_size, attn_type, attn_windowing,
attn_norm, prenet_type, prenet_dropout, forward_attn,
trans_agent, forward_attn_mask, location_attn, attn_K,
separate_stopnet, speaker_embedding_dim):
super(Decoder, self).__init__()
self.r_init = r
self.r = r
self.in_features = in_features
self.in_channels = in_channels
self.max_decoder_steps = 500
self.use_memory_queue = memory_size > 0
self.memory_size = memory_size if memory_size > 0 else r
self.memory_dim = memory_dim
self.frame_channels = frame_channels
self.separate_stopnet = separate_stopnet
self.query_dim = 256
# memory -> |Prenet| -> processed_memory
prenet_dim = memory_dim * self.memory_size + speaker_embedding_dim if self.use_memory_queue else memory_dim + speaker_embedding_dim
prenet_dim = frame_channels * self.memory_size + speaker_embedding_dim if self.use_memory_queue else frame_channels + speaker_embedding_dim
self.prenet = Prenet(
prenet_dim,
prenet_type,
@@ -286,11 +311,11 @@ class Decoder(nn.Module):
out_features=[256, 128])
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
# attention_rnn generates queries for the attention mechanism
self.attention_rnn = nn.GRUCell(in_features + 128, self.query_dim)
self.attention_rnn = nn.GRUCell(in_channels + 128, self.query_dim)
self.attention = init_attn(attn_type=attn_type,
query_dim=self.query_dim,
embedding_dim=in_features,
embedding_dim=in_channels,
attention_dim=128,
location_attention=location_attn,
attention_location_n_filters=32,
@@ -302,14 +327,14 @@ class Decoder(nn.Module):
forward_attn_mask=forward_attn_mask,
attn_K=attn_K)
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
self.project_to_decoder_in = nn.Linear(256 + in_channels, 256)
# decoder_RNN_input -> |RNN| -> RNN_state
self.decoder_rnns = nn.ModuleList(
[nn.GRUCell(256, 256) for _ in range(2)])
# RNN_state -> |Linear| -> mel_spec
self.proj_to_mel = nn.Linear(256, memory_dim * self.r_init)
self.proj_to_mel = nn.Linear(256, frame_channels * self.r_init)
# learn init values instead of zero init.
self.stopnet = StopNet(256 + memory_dim * self.r_init)
self.stopnet = StopNet(256 + frame_channels * self.r_init)
def set_r(self, new_r):
self.r = new_r
@@ -319,9 +344,9 @@ class Decoder(nn.Module):
Reshape the spectrograms for the given 'r'
"""
# Grouping multiple frames if necessary
if memory.size(-1) == self.memory_dim:
if memory.size(-1) == self.frame_channels:
memory = memory.view(memory.shape[0], memory.size(1) // self.r, -1)
# Time first (T_decoder, B, memory_dim)
# Time first (T_decoder, B, frame_channels)
memory = memory.transpose(0, 1)
return memory
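
With concrete sizes, the reshape above folds r consecutive target frames into one decoder step and moves time to the front:

import torch

B, T_mel, frame_channels, r = 2, 6, 80, 2
memory = torch.rand(B, T_mel, frame_channels)
memory = memory.view(B, T_mel // r, -1)  # (2, 3, 160): r frames per step
memory = memory.transpose(0, 1)          # (3, 2, 160): time-first
assert memory.shape == (T_mel // r, B, frame_channels * r)
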
@@ -333,16 +358,16 @@ class Decoder(nn.Module):
T = inputs.size(1)
# go frame as zeros matrix
if self.use_memory_queue:
self.memory_input = torch.zeros(1, device=inputs.device).repeat(B, self.memory_dim * self.memory_size)
self.memory_input = torch.zeros(1, device=inputs.device).repeat(B, self.frame_channels * self.memory_size)
else:
self.memory_input = torch.zeros(1, device=inputs.device).repeat(B, self.memory_dim)
self.memory_input = torch.zeros(1, device=inputs.device).repeat(B, self.frame_channels)
# decoder states
self.attention_rnn_hidden = torch.zeros(1, device=inputs.device).repeat(B, 256)
self.decoder_rnn_hiddens = [
torch.zeros(1, device=inputs.device).repeat(B, 256)
for idx in range(len(self.decoder_rnns))
]
self.context_vec = inputs.data.new(B, self.in_features).zero_()
self.context_vec = inputs.data.new(B, self.in_channels).zero_()
# cache attention inputs
self.processed_inputs = self.attention.preprocess_inputs(inputs)
@@ -352,7 +377,7 @@ class Decoder(nn.Module):
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
outputs = outputs.view(
outputs.size(0), -1, self.memory_dim)
outputs.size(0), -1, self.frame_channels)
outputs = outputs.transpose(1, 2)
return outputs, attentions, stop_tokens
@@ -386,7 +411,7 @@ class Decoder(nn.Module):
stop_token = self.stopnet(stopnet_input.detach())
else:
stop_token = self.stopnet(stopnet_input)
output = output[:, : self.r * self.memory_dim]
output = output[:, : self.r * self.frame_channels]
return output, stop_token, self.attention.attention_weights
def _update_memory_input(self, new_memory):
@@ -395,15 +420,15 @@ class Decoder(nn.Module):
# memory queue size is larger than number of frames per decoder iter
self.memory_input = torch.cat([
new_memory, self.memory_input[:, :(
self.memory_size - self.r) * self.memory_dim].clone()
self.memory_size - self.r) * self.frame_channels].clone()
], dim=-1)
else:
# memory queue size smaller than number of frames per decoder iter
self.memory_input = new_memory[:, :self.memory_size * self.memory_dim]
self.memory_input = new_memory[:, :self.memory_size * self.frame_channels]
else:
# use only the last frame prediction
# assert new_memory.shape[-1] == self.r * self.memory_dim
self.memory_input = new_memory[:, self.memory_dim * (self.r - 1):]
# assert new_memory.shape[-1] == self.r * self.frame_channels
self.memory_input = new_memory[:, self.frame_channels * (self.r - 1):]
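
The queue branch above is a sliding window over the last memory_size frames: the r newest frames are prepended and the r oldest dropped. A standalone illustration with concrete sizes:

import torch

B, frame_channels, r, memory_size = 2, 80, 2, 4
memory_input = torch.zeros(B, frame_channels * memory_size)  # current queue
new_memory = torch.rand(B, r * frame_channels)               # r fresh frames
memory_input = torch.cat(
    [new_memory, memory_input[:, :(memory_size - r) * frame_channels]],
    dim=-1)                                                  # drop oldest r
assert memory_input.shape == (B, frame_channels * memory_size)
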
def forward(self, inputs, memory, mask, speaker_embeddings=None):
"""
@@ -415,8 +440,8 @@ class Decoder(nn.Module):
mask: Attention mask for sequence padding.
Shapes:
- inputs: batch x time x encoder_out_dim
- memory: batch x #mel_specs x mel_spec_dim
- inputs: (B, T, D_out_enc)
- memory: (B, T_mel, D_mel)
"""
# Run greedy decoding if memory is None
memory = self._reshape_memory(memory)
@@ -446,8 +471,8 @@ class Decoder(nn.Module):
speaker_embeddings: speaker vectors.
Shapes:
- inputs: batch x time x encoder_out_dim
- speaker_embeddings: batch x embed_dim
- inputs: (B, T, D_out_enc)
- speaker_embeddings: (B, D_embed)
"""
outputs = []
attentions = []
@@ -478,7 +503,7 @@ class Decoder(nn.Module):
class StopNet(nn.Module):
r"""
r"""Stopnet signalling decoder to stop inference.
Args:
in_features (int): feature dimension of input.
"""


@@ -6,6 +6,18 @@ from .common_layers import init_attn, Prenet, Linear
class ConvBNBlock(nn.Module):
r"""Convolutions with Batch Normalization and non-linear activation.
Args:
in_channels (int): number of input channels.
out_channels (int): number of output channels.
kernel_size (int): convolution kernel size.
activation (str): 'relu', 'tanh', None (linear).
Shapes:
- input: (B, C_in, T)
- output: (B, C_out, T)
"""
def __init__(self, in_channels, out_channels, kernel_size, activation=None):
super(ConvBNBlock, self).__init__()
assert (kernel_size - 1) % 2 == 0
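
The assert admits only odd kernels, so 'same' padding of (kernel_size - 1) // 2 preserves T, matching the documented shapes. The rest of the constructor is outside this hunk; a sketch of the likely layout, with the activation handling assumed from the docstring:

import torch
import torch.nn as nn

class ConvBNBlockSketch(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, activation=None):
        super().__init__()
        assert (kernel_size - 1) % 2 == 0           # odd kernel only, so ...
        padding = (kernel_size - 1) // 2            # ... padding keeps T fixed
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              padding=padding)
        self.bn = nn.BatchNorm1d(out_channels)
        self.act = {'relu': nn.ReLU(), 'tanh': nn.Tanh(),
                    None: nn.Identity()}[activation]

    def forward(self, x):                           # (B, C_in, T)
        return self.act(self.bn(self.conv(x)))      # (B, C_out, T)

x = torch.rand(2, 80, 50)
assert ConvBNBlockSketch(80, 512, 5, 'tanh')(x).shape == (2, 512, 50)
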
@@ -32,16 +44,25 @@ class ConvBNBlock(nn.Module):
class Postnet(nn.Module):
def __init__(self, output_dim, num_convs=5):
r"""Tacotron2 Postnet
Args:
in_out_channels (int): number of input and output channels.
Shapes:
- input: (B, C_in, T)
- output: (B, C_in, T)
"""
def __init__(self, in_out_channels, num_convs=5):
super(Postnet, self).__init__()
self.convolutions = nn.ModuleList()
self.convolutions.append(
ConvBNBlock(output_dim, 512, kernel_size=5, activation='tanh'))
ConvBNBlock(in_out_channels, 512, kernel_size=5, activation='tanh'))
for _ in range(1, num_convs - 1):
self.convolutions.append(
ConvBNBlock(512, 512, kernel_size=5, activation='tanh'))
self.convolutions.append(
ConvBNBlock(512, output_dim, kernel_size=5, activation=None))
ConvBNBlock(512, in_out_channels, kernel_size=5, activation=None))
def forward(self, x):
o = x
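
The rest of forward() is truncated here; presumably it just chains the blocks (for layer in self.convolutions: o = layer(o)). Since the last ConvBNBlock is linear (activation=None), the postnet predicts a residual correction; the usual call site, sketched with a stand-in layer (the actual wiring lives in the model file, not this diff):

import torch
import torch.nn as nn

postnet = nn.Conv1d(80, 80, 5, padding=2)  # stand-in for Postnet(80)
mel = torch.rand(2, 80, 100)               # (B, C, T) coarse decoder output
mel_refined = mel + postnet(mel)           # residual refinement
assert mel_refined.shape == mel.shape
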
@@ -51,14 +72,23 @@ class Postnet(nn.Module):
class Encoder(nn.Module):
def __init__(self, output_input_dim=512):
r"""Tacotron2 Encoder
Args:
in_out_channels (int): number of input and output channels.
Shapes:
- input: (B, C_in, T)
- output: (B, C_in, T)
"""
def __init__(self, in_out_channels=512):
super(Encoder, self).__init__()
self.convolutions = nn.ModuleList()
for _ in range(3):
self.convolutions.append(
ConvBNBlock(output_input_dim, output_input_dim, 5, 'relu'))
self.lstm = nn.LSTM(output_input_dim,
int(output_input_dim / 2),
ConvBNBlock(in_out_channels, in_out_channels, 5, 'relu'))
self.lstm = nn.LSTM(in_out_channels,
int(in_out_channels / 2),
num_layers=1,
batch_first=True,
bias=True,
@@ -90,17 +120,39 @@ class Encoder(nn.Module):
# adapted from https://github.com/NVIDIA/tacotron2/
class Decoder(nn.Module):
"""Tacotron2 decoder. We don't use Zoneout but Dropout between RNN layers.
Args:
in_channels (int): number of input channels.
frame_channels (int): number of feature frame channels.
r (int): number of outputs per time step (reduction rate).
attn_type (string): type of attention used in decoder.
attn_win (bool): if true, define an attention window centered on the maximum
attention response. It provides more robust attention alignment, especially
at inference time.
attn_norm (string): attention normalization function. 'sigmoid' or 'softmax'.
prenet_type (string): 'original' or 'bn'.
prenet_dropout (float): prenet dropout rate.
forward_attn (bool): if true, use forward attention method. https://arxiv.org/abs/1807.06736
trans_agent (bool): if true, use transition agent. https://arxiv.org/abs/1807.06736
forward_attn_mask (bool): if true, mask attention values smaller than a threshold.
location_attn (bool): if true, use location sensitive attention.
attn_K (int): number of attention heads for GravesAttention.
separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow.
speaker_embedding_dim (int): size of speaker embedding vector, for multi-speaker training.
"""
# Pylint gets confused by PyTorch conventions here
#pylint: disable=attribute-defined-outside-init
def __init__(self, input_dim, frame_dim, r, attn_type, attn_win, attn_norm,
def __init__(self, in_channels, frame_channels, r, attn_type, attn_win, attn_norm,
prenet_type, prenet_dropout, forward_attn, trans_agent,
forward_attn_mask, location_attn, attn_K, separate_stopnet,
speaker_embedding_dim):
super(Decoder, self).__init__()
self.frame_dim = frame_dim
self.frame_channels = frame_channels
self.r_init = r
self.r = r
self.encoder_embedding_dim = input_dim
self.encoder_embedding_dim = in_channels
self.separate_stopnet = separate_stopnet
self.max_decoder_steps = 1000
self.stop_threshold = 0.5
@@ -114,20 +166,20 @@ class Decoder(nn.Module):
self.p_decoder_dropout = 0.1
# memory -> |Prenet| -> processed_memory
prenet_dim = self.frame_dim
prenet_dim = self.frame_channels
self.prenet = Prenet(prenet_dim,
prenet_type,
prenet_dropout,
out_features=[self.prenet_dim, self.prenet_dim],
bias=False)
self.attention_rnn = nn.LSTMCell(self.prenet_dim + input_dim,
self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_channels,
self.query_dim,
bias=True)
self.attention = init_attn(attn_type=attn_type,
query_dim=self.query_dim,
embedding_dim=input_dim,
embedding_dim=in_channels,
attention_dim=128,
location_attention=location_attn,
attention_location_n_filters=32,
@@ -139,16 +191,16 @@ class Decoder(nn.Module):
forward_attn_mask=forward_attn_mask,
attn_K=attn_K)
self.decoder_rnn = nn.LSTMCell(self.query_dim + input_dim,
self.decoder_rnn = nn.LSTMCell(self.query_dim + in_channels,
self.decoder_rnn_dim,
bias=True)
self.linear_projection = Linear(self.decoder_rnn_dim + input_dim,
self.frame_dim * self.r_init)
self.linear_projection = Linear(self.decoder_rnn_dim + in_channels,
self.frame_channels * self.r_init)
self.stopnet = nn.Sequential(
nn.Dropout(0.1),
Linear(self.decoder_rnn_dim + self.frame_dim * self.r_init,
Linear(self.decoder_rnn_dim + self.frame_channels * self.r_init,
1,
bias=True,
init_gain='sigmoid'))
@@ -160,7 +212,7 @@ class Decoder(nn.Module):
def get_go_frame(self, inputs):
B = inputs.size(0)
memory = torch.zeros(1, device=inputs.device).repeat(B,
self.frame_dim * self.r)
self.frame_channels * self.r)
return memory
def _init_states(self, inputs, mask, keep_states=False):
@@ -186,9 +238,9 @@ class Decoder(nn.Module):
Reshape the spectrograms for the given 'r'
"""
# Grouping multiple frames if necessary
if memory.size(-1) == self.frame_dim:
if memory.size(-1) == self.frame_channels:
memory = memory.view(memory.shape[0], memory.size(1) // self.r, -1)
# Time first (T_decoder, B, frame_dim)
# Time first (T_decoder, B, frame_channels)
memory = memory.transpose(0, 1)
return memory
@@ -196,22 +248,22 @@ class Decoder(nn.Module):
alignments = torch.stack(alignments).transpose(0, 1)
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
outputs = outputs.view(outputs.size(0), -1, self.frame_dim)
outputs = outputs.view(outputs.size(0), -1, self.frame_channels)
outputs = outputs.transpose(1, 2)
return outputs, stop_tokens, alignments
def _update_memory(self, memory):
if len(memory.shape) == 2:
return memory[:, self.frame_dim * (self.r - 1):]
return memory[:, :, self.frame_dim * (self.r - 1):]
return memory[:, self.frame_channels * (self.r - 1):]
return memory[:, :, self.frame_channels * (self.r - 1):]
def decode(self, memory):
'''
shapes:
- memory: B x r * self.frame_dim
- memory: B x r * self.frame_channels
'''
# self.context: B x D_en
# query_input: B x D_en + (r * self.frame_dim)
# query_input: B x D_en + (r * self.frame_channels)
query_input = torch.cat((memory, self.context), -1)
# self.query and self.attention_rnn_cell_state : B x D_attn_rnn
self.query, self.attention_rnn_cell_state = self.attention_rnn(
@@ -234,19 +286,32 @@ class Decoder(nn.Module):
# B x (D_decoder_rnn + D_en)
decoder_hidden_context = torch.cat((self.decoder_hidden, self.context),
dim=1)
# B x (self.r * self.frame_dim)
# B x (self.r * self.frame_channels)
decoder_output = self.linear_projection(decoder_hidden_context)
# B x (D_decoder_rnn + (self.r * self.frame_dim))
# B x (D_decoder_rnn + (self.r * self.frame_channels))
stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1)
if self.separate_stopnet:
stop_token = self.stopnet(stopnet_input.detach())
else:
stop_token = self.stopnet(stopnet_input)
# select outputs for the reduction rate self.r
decoder_output = decoder_output[:, :self.r * self.frame_dim]
decoder_output = decoder_output[:, :self.r * self.frame_channels]
return decoder_output, self.attention.attention_weights, stop_token
def forward(self, inputs, memories, mask, speaker_embeddings=None):
r"""Train Decoder with teacher forcing.
Args:
inputs: Encoder outputs.
memories: Feature frames for teacher-forcing.
mask: Attention mask for sequence padding.
Shapes:
- inputs: (B, T, D_out_enc)
- memories: (B, T_mel, D_mel)
- outputs: (B, T_mel, D_mel)
- alignments: (B, T_in, T_out)
- stop_tokens: (B, T_out)
"""
memory = self.get_go_frame(inputs).unsqueeze(0)
memories = self._reshape_memory(memories)
memories = torch.cat((memory, memories), dim=0)
@@ -271,6 +336,19 @@ class Decoder(nn.Module):
return outputs, alignments, stop_tokens
def inference(self, inputs, speaker_embeddings=None):
r"""Decoder inference without teacher forcing and use
Stopnet to stop decoder.
Args:
inputs: Encoder outputs.
speaker_embeddings: speaker embedding vectors.
Shapes:
- inputs: (B, T, D_out_enc)
- speaker_embeddings: (B, D_embed)
- outputs: (B, T_mel, D_mel)
- alignments: (B, T_in, T_out)
- stop_tokens: (B, T_out)
"""
memory = self.get_go_frame(inputs)
memory = self._update_memory(memory)
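
The remainder of the inference loop is cut off by the hunk. A sketch of the loop that plausibly follows, written as it would sit in the method body; it leans on max_decoder_steps and stop_threshold from the constructor above, but the exact body is an assumption, not part of this diff:

# self._init_states(inputs, mask=None)
# outputs, alignments, stop_tokens = [], [], []
# while True:
#     memory = self.prenet(memory)
#     if speaker_embeddings is not None:
#         memory = torch.cat([memory, speaker_embeddings], dim=-1)
#     output, alignment, stop_token = self.decode(memory)
#     stop_token = torch.sigmoid(stop_token.data)
#     outputs += [output.squeeze(1)]
#     alignments += [alignment]
#     stop_tokens += [stop_token]
#     if stop_token > self.stop_threshold:        # StopNet says halt
#         break
#     if len(outputs) == self.max_decoder_steps:  # hard safety limit
#         break
#     memory = self._update_memory(output)        # feed back last frame
# return self._parse_outputs(outputs, stop_tokens, alignments)
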


@@ -44,8 +44,8 @@ class DecoderTests(unittest.TestCase):
@staticmethod
def test_in_out():
layer = Decoder(
in_features=256,
memory_dim=80,
in_channels=256,
frame_channels=80,
r=2,
memory_size=4,
attn_windowing=False,
@@ -74,8 +74,8 @@ class DecoderTests(unittest.TestCase):
@staticmethod
def test_in_out_multispeaker():
layer = Decoder(
in_features=256,
memory_dim=80,
in_channels=256,
frame_channels=80,
r=2,
memory_size=4,
attn_windowing=False,