Better naming of variables

Eren Golge 2018-02-05 06:27:02 -08:00
parent 9f5e102473
commit 2fd37a5bad
4 changed files with 57 additions and 54 deletions

View File

@ -15,7 +15,7 @@
"lr": 0.003,
"lr_patience": 5,
"lr_decay": 0.5,
"batch_size": 128,
"batch_size": 98,
"r": 5,
"griffin_lim_iters": 60,

View File

@ -11,11 +11,11 @@ class BahdanauAttention(nn.Module):
self.tanh = nn.Tanh()
self.v = nn.Linear(dim, 1, bias=False)
def forward(self, query, processed_memory):
def forward(self, query, processed_inputs):
"""
Args:
query: (batch, 1, dim) or (batch, dim)
processed_memory: (batch, max_time, dim)
processed_inputs: (batch, max_time, dim)
"""
if query.dim() == 2:
# insert time-axis for broadcasting
@ -24,7 +24,7 @@ class BahdanauAttention(nn.Module):
processed_query = self.query_layer(query)
# (batch, max_time, 1)
alignment = self.v(self.tanh(processed_query + processed_memory))
alignment = self.v(self.tanh(processed_query + processed_inputs))
# (batch, max_time)
return alignment.squeeze(-1)
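For reference, the additive (Bahdanau) score computed above, v · tanh(W·query + processed_inputs), can be traced shape-by-shape with dummy tensors; the names below mirror the module but this is an illustrative sketch, not its API:

import torch
import torch.nn as nn

dim, batch, max_time = 256, 2, 7
query_layer = nn.Linear(dim, dim, bias=False)      # stands in for self.query_layer
v = nn.Linear(dim, 1, bias=False)
tanh = nn.Tanh()

query = torch.randn(batch, dim)                       # (batch, dim)
processed_inputs = torch.randn(batch, max_time, dim)  # (batch, max_time, dim)

if query.dim() == 2:
    query = query.unsqueeze(1)         # insert time axis -> (batch, 1, dim)
processed_query = query_layer(query)   # (batch, 1, dim), broadcasts over max_time
alignment = v(tanh(processed_query + processed_inputs)).squeeze(-1)  # (batch, max_time)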
@ -44,44 +44,48 @@ def get_mask_from_lengths(memory, memory_lengths):
class AttentionWrapper(nn.Module):
def __init__(self, rnn_cell, attention_mechanism,
def __init__(self, rnn_cell, alignment_model,
score_mask_value=-float("inf")):
super(AttentionWrapper, self).__init__()
self.rnn_cell = rnn_cell
self.attention_mechanism = attention_mechanism
self.alignment_model = alignment_model
self.score_mask_value = score_mask_value
def forward(self, query, attention, cell_state, memory,
processed_memory=None, mask=None, memory_lengths=None):
if processed_memory is None:
processed_memory = memory
def forward(self, query, context_vec, cell_state, memory,
processed_inputs=None, mask=None, memory_lengths=None):
if processed_inputs is None:
processed_inputs = memory
if memory_lengths is not None and mask is None:
mask = get_mask_from_lengths(memory, memory_lengths)
# Concat input query and previous attention context
cell_input = torch.cat((query, attention), -1)
# Concat input query and previous attention context vector
cell_input = torch.cat((query, context_vec), -1)
# Feed it to RNN
cell_output = self.rnn_cell(cell_input, cell_state)
# Alignment
# (batch, max_time)
alignment = self.attention_mechanism(cell_output, processed_memory)
alignment = self.alignment_model(cell_output, processed_inputs)
if mask is not None:
mask = mask.view(query.size(0), -1)
alignment.data.masked_fill_(mask, self.score_mask_value)
# Normalize attention weight
# Normalize attention weights
alignment = F.softmax(alignment, dim=-1)
# Attention context vector
# (batch, 1, dim)
attention = torch.bmm(alignment.unsqueeze(1), memory)
context_vec = torch.bmm(alignment.unsqueeze(1), memory)
# (batch, dim)
attention = attention.squeeze(1)
context_vec = context_vec.squeeze(1)
return cell_output, context_vec, alignment
return cell_output, attention, alignment
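The wrapper's mask -> softmax -> bmm sequence can be shown in isolation with dummy tensors; the mask construction below only approximates get_mask_from_lengths, which is defined outside this hunk:

import torch
import torch.nn.functional as F

batch, max_time, dim = 2, 7, 256
alignment = torch.randn(batch, max_time)     # raw scores from the alignment model
memory = torch.randn(batch, max_time, dim)   # encoder outputs being attended over
lengths = torch.tensor([7, 4])               # valid encoder steps per batch item

# True where the encoder output is padding; filled with -inf before softmax
mask = torch.arange(max_time).unsqueeze(0) >= lengths.unsqueeze(1)
alignment = alignment.masked_fill(mask, -float("inf"))

weights = F.softmax(alignment, dim=-1)                            # (batch, max_time)
context_vec = torch.bmm(weights.unsqueeze(1), memory).squeeze(1)  # (batch, dim)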

View File

@ -152,29 +152,27 @@ class Encoder(nn.Module):
class Decoder(nn.Module):
def __init__(self, memory_dim, r):
super(Decoder, self).__init__()
self.max_decoder_steps = 200
self.memory_dim = memory_dim
self.r = r
# input -> |Linear| -> processed_inputs
self.input_layer = nn.Linear(256, 256, bias=False)
# memory -> |Prenet| -> processed_memory
self.prenet = Prenet(memory_dim * r, sizes=[256, 128])
# attention RNN
# processed_inputs, processed_memory -> |Attention| -> Attention, Alignment, RNN_State
self.attention_rnn = AttentionWrapper(
nn.GRUCell(256 + 128, 256),
BahdanauAttention(256)
)
self.memory_layer = nn.Linear(256, 256, bias=False)
# concat and project context and attention vectors
# (prenet_out + attention context) -> output
# (prenet_out | attention context) -> |Linear| -> decoder_RNN_input
self.project_to_decoder_in = nn.Linear(512, 256)
# decoder RNNs
# decoder_RNN_input -> |RNN| -> RNN_state
self.decoder_rnns = nn.ModuleList(
[nn.GRUCell(256, 256) for _ in range(2)])
# RNN_state -> |Linear| -> mel_spec
self.proj_to_mel = nn.Linear(256, memory_dim * r)
self.max_decoder_steps = 200
def forward(self, decoder_inputs, memory=None, memory_lengths=None):
def forward(self, inputs, memory=None, memory_lengths=None):
"""
Decoder forward step.
@ -182,17 +180,18 @@ class Decoder(nn.Module):
Tacotron paper, greedy decoding is adopted.
Args:
decoder_inputs: Encoder outputs. (B, T_encoder, dim)
inputs: Encoder outputs. (B, T_encoder, dim)
memory: Decoder memory. i.e., mel-spectrogram. If None (at eval-time),
decoder outputs are used as decoder inputs.
memory_lengths: Encoder output (memory) lengths. If not None, used for
attention masking.
"""
B = decoder_inputs.size(0)
B = inputs.size(0)
processed_memory = self.memory_layer(decoder_inputs)
# TODO: take this segment into the Attention module.
processed_inputs = self.input_layer(inputs)
if memory_lengths is not None:
mask = get_mask_from_lengths(processed_memory, memory_lengths)
mask = get_mask_from_lengths(processed_inputs, memory_lengths)
else:
mask = None
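A rough, runnable shape trace of one decoder step, built from the layer sizes declared above; memory_dim = 80 mel bins is an assumption (only r = 5 appears in the config), and bare nn.Linear/nn.GRUCell calls stand in for the Prenet and AttentionWrapper:

import torch
import torch.nn as nn

B, T_in, memory_dim, r = 2, 7, 80, 5
inputs = torch.randn(B, T_in, 256)              # encoder outputs
prev_frames = torch.randn(B, memory_dim * r)    # previous r mel frames (or the go frame)

prenet = nn.Sequential(nn.Linear(memory_dim * r, 256), nn.ReLU(),
                       nn.Linear(256, 128), nn.ReLU())
prenet_out = prenet(prev_frames)                                    # (B, 128)
context_vec = torch.zeros(B, 256)                                   # initial attention context
attn_hidden = nn.GRUCell(256 + 128, 256)(
    torch.cat((prenet_out, context_vec), -1))                       # (B, 256)
decoder_in = nn.Linear(512, 256)(
    torch.cat((attn_hidden, context_vec), -1))                      # (B, 256)
mel_out = nn.Linear(256, memory_dim * r)(nn.GRUCell(256, 256)(decoder_in))  # (B, 400): r frames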
@ -208,18 +207,18 @@ class Decoder(nn.Module):
self.memory_dim, self.r)
T_decoder = memory.size(1)
# go frames - 0 frames starting the sequence
initial_input = Variable(
decoder_inputs.data.new(B, self.memory_dim * self.r).zero_())
# go frame - 0 frames starting the sequence
initial_memory = Variable(
inputs.data.new(B, self.memory_dim * self.r).zero_())
# Init decoder states
attention_rnn_hidden = Variable(
decoder_inputs.data.new(B, 256).zero_())
inputs.data.new(B, 256).zero_())
decoder_rnn_hiddens = [Variable(
decoder_inputs.data.new(B, 256).zero_())
inputs.data.new(B, 256).zero_())
for _ in range(len(self.decoder_rnns))]
current_attention = Variable(
decoder_inputs.data.new(B, 256).zero_())
current_context_vec = Variable(
inputs.data.new(B, 256).zero_())
# Time first (T_decoder, B, memory_dim)
if memory is not None:
@ -229,21 +228,21 @@ class Decoder(nn.Module):
alignments = []
t = 0
current_input = initial_input
memory_input = initial_memory
while True:
if t > 0:
current_input = outputs[-1] if greedy else memory[t - 1]
memory_input = outputs[-1] if greedy else memory[t - 1]
# Prenet
current_input = self.prenet(current_input)
memory_input = self.prenet(memory_input)
# Attention RNN
attention_rnn_hidden, current_attention, alignment = self.attention_rnn(
current_input, current_attention, attention_rnn_hidden,
decoder_inputs, processed_memory=processed_memory, mask=mask)
attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn(
memory_input, current_context_vec, attention_rnn_hidden,
inputs, processed_inputs=processed_inputs, mask=mask)
# Concat RNN output and attention context vector
decoder_input = self.project_to_decoder_in(
torch.cat((attention_rnn_hidden, current_attention), -1))
torch.cat((attention_rnn_hidden, current_context_vec), -1))
# Pass through the decoder RNNs
for idx in range(len(self.decoder_rnns)):
@ -266,7 +265,8 @@ class Decoder(nn.Module):
if t > 1 and is_end_of_frames(output):
break
elif t > self.max_decoder_steps:
print("Warning! doesn't seems to be converged")
print(" !! Decoder stopped with 'max_decoder_steps'. \
Something is probably wrong.")
break
else:
if t >= T_decoder:
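is_end_of_frames is referenced in the stop condition above but defined outside this diff; a hypothetical version might treat a near-silent predicted frame block as the end of decoding:

import torch

def is_end_of_frames(output, eps=0.2):
    """Hypothetical stop test: output is a (B, memory_dim * r) frame block."""
    return (output.data.abs() <= eps).all()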

View File

@ -124,7 +124,7 @@ def main(args):
print("\n | > Epoch {}/{}".format(epoch, c.epochs))
progbar = Progbar(len(dataset) / c.batch_size)
for i, data in enumerate(dataloader):
for num_iter, data in enumerate(dataloader):
start_time = time.time()
text_input = data[0]
@ -132,8 +132,7 @@ def main(args):
magnitude_input = data[2]
mel_input = data[3]
current_step = i + args.restore_step + epoch * len(dataloader) + 1
print(current_step)
current_step = num_iter + args.restore_step + epoch * len(dataloader) + 1
# setup lr
current_lr = lr_decay(c.lr, current_step)
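lr_decay itself is not shown in this diff; one plausible step-based schedule (Noam-style warmup followed by inverse-square-root decay) that fits the lr_decay(c.lr, current_step) call, plus the usual way of applying it to an optimizer:

def lr_decay(init_lr, global_step, warmup_steps=4000.0):
    # assumption: warmup then inverse-sqrt decay; the real schedule may differ
    step = global_step + 1.0
    return init_lr * warmup_steps**0.5 * min(step * warmup_steps**-1.5, step**-0.5)

# for group in optimizer.param_groups:   # apply to an existing torch optimizer
#     group['lr'] = current_lr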
@ -190,10 +189,10 @@ def main(args):
step_time = time.time() - start_time
epoch_time += step_time
progbar.update(i+1, values=[('total_loss', loss.data[0]),
('linear_loss', linear_loss.data[0]),
('mel_loss', mel_loss.data[0]),
('grad_norm', grad_norm)])
progbar.update(num_iter+1, values=[('total_loss', loss.data[0]),
('linear_loss', linear_loss.data[0]),
('mel_loss', mel_loss.data[0]),
('grad_norm', grad_norm)])
# Plot Learning Stats
tb.add_scalar('Loss/TotalLoss', loss.data[0], current_step)
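The tb logger used here is constructed outside this hunk; a minimal sketch assuming a tensorboardX-style SummaryWriter:

from tensorboardX import SummaryWriter  # assumption: tensorboardX-style API

tb = SummaryWriter(log_dir="logs")
# tb.add_scalar('Loss/TotalLoss', loss.data[0], current_step)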