From 2fd37a5bad9374002873db26927694dd0d329e90 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Mon, 5 Feb 2018 06:27:02 -0800
Subject: [PATCH] Better naming of variables

---
 config.json         |  2 +-
 layers/attention.py | 34 ++++++++++++++++-------------
 layers/tacotron.py  | 60 ++++++++++++++++++++++-----------------------
 train.py            | 13 +++++------
 4 files changed, 55 insertions(+), 54 deletions(-)

diff --git a/config.json b/config.json
index 34b79408..26a0b3a7 100644
--- a/config.json
+++ b/config.json
@@ -15,7 +15,7 @@
     "lr": 0.003,
     "lr_patience": 5,
     "lr_decay": 0.5,
-    "batch_size": 128,
+    "batch_size": 98,
     "r": 5,
 
     "griffin_lim_iters": 60,
diff --git a/layers/attention.py b/layers/attention.py
index 902975f4..0e24f383 100644
--- a/layers/attention.py
+++ b/layers/attention.py
@@ -11,11 +11,11 @@ class BahdanauAttention(nn.Module):
         self.tanh = nn.Tanh()
         self.v = nn.Linear(dim, 1, bias=False)
 
-    def forward(self, query, processed_memory):
+    def forward(self, query, processed_inputs):
         """
         Args:
             query: (batch, 1, dim) or (batch, dim)
-            processed_memory: (batch, max_time, dim)
+            processed_inputs: (batch, max_time, dim)
         """
         if query.dim() == 2:
             # insert time-axis for broadcasting
@@ -24,7 +24,7 @@ class BahdanauAttention(nn.Module):
         processed_query = self.query_layer(query)
 
         # (batch, max_time, 1)
-        alignment = self.v(self.tanh(processed_query + processed_memory))
+        alignment = self.v(self.tanh(processed_query + processed_inputs))
 
         # (batch, max_time)
         return alignment.squeeze(-1)
@@ -44,44 +44,46 @@ def get_mask_from_lengths(memory, memory_lengths):
 
 
 class AttentionWrapper(nn.Module):
-    def __init__(self, rnn_cell, attention_mechanism,
+    def __init__(self, rnn_cell, alignment_model,
                  score_mask_value=-float("inf")):
         super(AttentionWrapper, self).__init__()
         self.rnn_cell = rnn_cell
-        self.attention_mechanism = attention_mechanism
+        self.alignment_model = alignment_model
         self.score_mask_value = score_mask_value
 
-    def forward(self, query, attention, cell_state, memory,
-                processed_memory=None, mask=None, memory_lengths=None):
-        if processed_memory is None:
-            processed_memory = memory
+    def forward(self, query, context_vec, cell_state, memory,
+                processed_inputs=None, mask=None, memory_lengths=None):
+
+        if processed_inputs is None:
+            processed_inputs = memory
         if memory_lengths is not None and mask is None:
             mask = get_mask_from_lengths(memory, memory_lengths)
 
-        # Concat input query and previous attention context
-        cell_input = torch.cat((query, attention), -1)
+        # Concat input query and previous context vector
+        cell_input = torch.cat((query, context_vec), -1)
         # Feed it to RNN
         cell_output = self.rnn_cell(cell_input, cell_state)
 
         # Alignment
         # (batch, max_time)
-        alignment = self.attention_mechanism(cell_output, processed_memory)
+        alignment = self.alignment_model(cell_output, processed_inputs)
 
         if mask is not None:
            mask = mask.view(query.size(0), -1)
            alignment.data.masked_fill_(mask, self.score_mask_value)
 
-        # Normalize attention weight
+        # Normalize alignment weights
         alignment = F.softmax(alignment, dim=-1)
 
         # Attention context vector
         # (batch, 1, dim)
-        attention = torch.bmm(alignment.unsqueeze(1), memory)
+        context_vec = torch.bmm(alignment.unsqueeze(1), memory)
 
         # (batch, dim)
-        attention = attention.squeeze(1)
+        context_vec = context_vec.squeeze(1)
+
+        return cell_output, context_vec, alignment
 
-        return cell_output, attention, alignment
diff --git a/layers/tacotron.py b/layers/tacotron.py
index fbefbc4f..4fbe7f17 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -152,29 +152,27 @@ class Encoder(nn.Module):
 class Decoder(nn.Module):
     def __init__(self, memory_dim, r):
         super(Decoder, self).__init__()
+        self.max_decoder_steps = 200
         self.memory_dim = memory_dim
         self.r = r
+        # input -> |Linear| -> processed_inputs
+        self.input_layer = nn.Linear(256, 256, bias=False)
+        # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(memory_dim * r, sizes=[256, 128])
-        # attetion RNN
+        # processed_inputs, processed_memory -> |Attention| -> Attention, Alignment, RNN_State
        self.attention_rnn = AttentionWrapper(
            nn.GRUCell(256 + 128, 256),
            BahdanauAttention(256)
        )
-
-        self.memory_layer = nn.Linear(256, 256, bias=False)
-
-        # concat and project context and attention vectors
-        # (prenet_out + attention context) -> output
+        # (prenet_out | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(512, 256)
-
-        # decoder RNNs
+        # decoder_RNN_input -> |RNN| -> RNN_state
         self.decoder_rnns = nn.ModuleList(
             [nn.GRUCell(256, 256) for _ in range(2)])
-
+        # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
-        self.max_decoder_steps = 200
 
-    def forward(self, decoder_inputs, memory=None, memory_lengths=None):
+    def forward(self, inputs, memory=None, memory_lengths=None):
         """
         Decoder forward step.
 
@@ -182,17 +180,18 @@ class Decoder(nn.Module):
         Tacotron paper, greedy decoding is adapted.
 
         Args:
-            decoder_inputs: Encoder outputs. (B, T_encoder, dim)
+            inputs: Encoder outputs. (B, T_encoder, dim)
             memory: Decoder memory. i.e., mel-spectrogram. If None (at eval-time),
             decoder outputs are used as decoder inputs.
             memory_lengths: Encoder output (memory) lengths. If not None, used for
             attention masking.
         """
-        B = decoder_inputs.size(0)
+        B = inputs.size(0)
 
-        processed_memory = self.memory_layer(decoder_inputs)
+        # TODO: move this segment into the Attention module.
+        processed_inputs = self.input_layer(inputs)
         if memory_lengths is not None:
-            mask = get_mask_from_lengths(processed_memory, memory_lengths)
+            mask = get_mask_from_lengths(processed_inputs, memory_lengths)
         else:
             mask = None
 
@@ -208,18 +207,18 @@
                 self.memory_dim, self.r)
             T_decoder = memory.size(1)
 
-        # go frames - 0 frames tarting the sequence
-        initial_input = Variable(
-            decoder_inputs.data.new(B, self.memory_dim * self.r).zero_())
+        # go frame - zero frame starting the sequence
+        initial_memory = Variable(
+            inputs.data.new(B, self.memory_dim * self.r).zero_())
 
         # Init decoder states
         attention_rnn_hidden = Variable(
-            decoder_inputs.data.new(B, 256).zero_())
+            inputs.data.new(B, 256).zero_())
         decoder_rnn_hiddens = [Variable(
-            decoder_inputs.data.new(B, 256).zero_())
+            inputs.data.new(B, 256).zero_())
             for _ in range(len(self.decoder_rnns))]
-        current_attention = Variable(
-            decoder_inputs.data.new(B, 256).zero_())
+        current_context_vec = Variable(
+            inputs.data.new(B, 256).zero_())
 
         # Time first (T_decoder, B, memory_dim)
         if memory is not None:
@@ -229,21 +228,21 @@
         alignments = []
 
         t = 0
-        current_input = initial_input
+        memory_input = initial_memory
         while True:
             if t > 0:
-                current_input = outputs[-1] if greedy else memory[t - 1]
+                memory_input = outputs[-1] if greedy else memory[t - 1]
             # Prenet
-            current_input = self.prenet(current_input)
+            memory_input = self.prenet(memory_input)
 
             # Attention RNN
-            attention_rnn_hidden, current_attention, alignment = self.attention_rnn(
-                current_input, current_attention, attention_rnn_hidden,
-                decoder_inputs, processed_memory=processed_memory, mask=mask)
+            attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn(
+                memory_input, current_context_vec, attention_rnn_hidden,
+                inputs, processed_inputs=processed_inputs, mask=mask)
 
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(
-                torch.cat((attention_rnn_hidden, current_attention), -1))
+                torch.cat((attention_rnn_hidden, current_context_vec), -1))
 
             # Pass through the decoder RNNs
             for idx in range(len(self.decoder_rnns)):
@@ -266,7 +265,8 @@
             if t > 1 and is_end_of_frames(output):
                 break
             elif t > self.max_decoder_steps:
-                print("Warning! doesn't seems to be converged")
+                print(" !! Decoder stopped with 'max_decoder_steps'. "
+                      "Something is probably wrong.")
                 break
         else:
             if t >= T_decoder:
diff --git a/train.py b/train.py
index 7602f81d..58097366 100644
--- a/train.py
+++ b/train.py
@@ -124,7 +124,7 @@ def main(args):
         print("\n | > Epoch {}/{}".format(epoch, c.epochs))
         progbar = Progbar(len(dataset) / c.batch_size)
 
-        for i, data in enumerate(dataloader):
+        for num_iter, data in enumerate(dataloader):
             start_time = time.time()
 
             text_input = data[0]
@@ -132,8 +132,7 @@ def main(args):
             magnitude_input = data[2]
             mel_input = data[3]
 
-            current_step = i + args.restore_step + epoch * len(dataloader) + 1
-            print(current_step)
+            current_step = num_iter + args.restore_step + epoch * len(dataloader) + 1
 
             # setup lr
             current_lr = lr_decay(c.lr, current_step)
@@ -190,10 +189,10 @@ def main(args):
             step_time = time.time() - start_time
             epoch_time += step_time
 
-            progbar.update(i+1, values=[('total_loss', loss.data[0]),
-                                        ('linear_loss', linear_loss.data[0]),
-                                        ('mel_loss', mel_loss.data[0]),
-                                        ('grad_norm', grad_norm)])
+            progbar.update(num_iter+1, values=[('total_loss', loss.data[0]),
+                                               ('linear_loss', linear_loss.data[0]),
+                                               ('mel_loss', mel_loss.data[0]),
+                                               ('grad_norm', grad_norm)])
 
             # Plot Learning Stats
             tb.add_scalar('Loss/TotalLoss', loss.data[0], current_step)
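Not part of the patch: a minimal smoke-test sketch of the renamed AttentionWrapper interface, showing the new names (alignment_model, context_vec, processed_inputs) and the return order (cell_output, context_vec, alignment). The tensor sizes follow the 256/128 dimensions wired up in layers/tacotron.py; the import path assumes the repository root is on PYTHONPATH, and the Variable usage matches the PyTorch 0.3-era API used elsewhere in this codebase.

import torch
from torch import nn
from torch.autograd import Variable

from layers.attention import AttentionWrapper, BahdanauAttention

# Arbitrary batch size and encoder length for the smoke test.
B, T_encoder = 2, 11

# Same wiring as Decoder.__init__: GRU input = prenet output (128) + context vector (256).
attention_rnn = AttentionWrapper(
    nn.GRUCell(256 + 128, 256),
    BahdanauAttention(256)
)

inputs = Variable(torch.zeros(B, T_encoder, 256))            # encoder outputs
processed_inputs = Variable(torch.zeros(B, T_encoder, 256))  # would be input_layer(inputs) in the decoder
memory_input = Variable(torch.zeros(B, 128))                 # prenet output for one decoder step
context_vec = Variable(torch.zeros(B, 256))                  # previous attention context vector
cell_state = Variable(torch.zeros(B, 256))                   # previous attention-RNN hidden state

cell_output, context_vec, alignment = attention_rnn(
    memory_input, context_vec, cell_state, inputs,
    processed_inputs=processed_inputs)

print(cell_output.size())   # (B, 256) -> new attention-RNN hidden state
print(context_vec.size())   # (B, 256) -> attention context vector
print(alignment.size())     # (B, T_encoder) -> normalized alignment weights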