config update and bug fixes

Eren Golge 2019-10-31 16:31:49 +01:00
parent adf9ebd629
commit b9e0faca98
3 changed files with 11 additions and 3 deletions

View File

@@ -34,6 +34,7 @@
"reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.
// TRAINING
"batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
"eval_batch_size":16,
@@ -47,6 +48,9 @@
"test_sentences_file": null, // set a file to load sentences to be used for testing. If it is null then we use default english sentences.
"model": "Tacotron2", // one of the models under models/
// OPTIMIZER
"grad_clip": 1, // upper limit for gradients for clipping.
"epochs": 1000, // total number of epochs to train.
"lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
@@ -59,8 +63,12 @@
"prenet_type": "original", // "original" or "bn".
"prenet_dropout": true, // enable/disable dropout at prenet.
// ATTENTION
"attention_type": "graves", // 'original' or 'graves'
"attention_heads": 5, // number of attention heads (only for 'graves')
"attention_norm": "sigmoid", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
"windowing": false, // Enables attention windowing. Used only in eval mode.

View File

@@ -134,11 +134,12 @@ class GravesAttention(nn.Module):
def preprocess_inputs(self, inputs):
return None
def forward(self, query, inputs, mask):
def forward(self, query, inputs, processed_inputs, mask):
"""
shapes:
query: B x D_attention_rnn
inputs: B x T_in x D_encoder
processed_inputs: placeholder (unused; GravesAttention does not precompute encoder projections)
mask: B x T_in
"""
gbk_t = self.N_a(query)
@@ -176,7 +177,6 @@ class GravesAttention(nn.Module):
context = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1)
self.attention_weights = alpha_t
self.mu_prev = mu_t
breakpoint()
return context
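The signature change above gives GravesAttention the same call contract as the other attention modules: preprocess_inputs() plus forward(query, inputs, processed_inputs, mask), with processed_inputs accepted but ignored. Below is a stripped-down stand-in illustrating just that contract (toy uniform attention weights, not the actual Graves GMM computation):

import torch
import torch.nn as nn

class UniformAttention(nn.Module):
    """Toy module showing only the shared attention interface."""
    def preprocess_inputs(self, inputs):
        # Like GravesAttention, nothing to precompute from the encoder outputs.
        return None

    def forward(self, query, inputs, processed_inputs, mask):
        # query: B x D_attention_rnn, inputs: B x T_in x D_encoder, mask: B x T_in
        B, T, _ = inputs.shape
        alpha_t = torch.full((B, T), 1.0 / T, device=inputs.device)
        if mask is not None:
            alpha_t = alpha_t * mask
            alpha_t = alpha_t / alpha_t.sum(1, keepdim=True)
        # Same context computation as in GravesAttention above.
        return torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1)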

View File

@@ -180,7 +180,7 @@ class Decoder(nn.Module):
self.context = torch.zeros(1, device=inputs.device).repeat(
B, self.encoder_embedding_dim)
self.inputs = inputs
self.processed_inputs = self.attention.inputs_layer(inputs)
self.processed_inputs = self.attention.preprocess_inputs(inputs)
self.mask = mask
def _reshape_memory(self, memory):
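With _init_states now asking the attention module itself for its preprocessed inputs, each attention class decides what (if anything) to precompute; for GravesAttention the stored processed_inputs is simply None. A hedged sketch of how a single decode step would then consume that state, using the UniformAttention stand-in from the sketch above (shapes and names here are illustrative):

import torch

B, T_in, D_encoder, D_attention_rnn = 2, 7, 512, 1024
inputs = torch.randn(B, T_in, D_encoder)        # encoder outputs
mask = torch.ones(B, T_in)                      # toy batch with no padding
query = torch.randn(B, D_attention_rnn)         # attention-RNN hidden state

attention = UniformAttention()                  # stand-in from the sketch above
processed_inputs = attention.preprocess_inputs(inputs)   # None, as for Graves
context = attention(query, inputs, processed_inputs, mask)
assert context.shape == (B, D_encoder)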