mirror of https://github.com/coqui-ai/TTS.git
config update and bug fixes
This commit is contained in:
parent adf9ebd629
commit b9e0faca98
@@ -34,6 +34,7 @@
     "reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.
 
+<<<<<<< HEAD
     // TRAINING
     "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
     "eval_batch_size":16,
@@ -47,6 +48,9 @@
     "test_sentences_file": null, // set a file to load sentences to be used for testing. If it is null then we use default english sentences.
 
     // OPTIMIZER
+=======
+    "model": "Tacotron2", // one of the model in models/
+>>>>>>> config update and bug fixes
     "grad_clip": 1, // upper limit for gradients for clipping.
     "epochs": 1000, // total number of epochs to train.
     "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
@@ -59,8 +63,12 @@
     "prenet_type": "original", // "original" or "bn".
     "prenet_dropout": true, // enable/disable dropout at prenet.
 
+<<<<<<< HEAD
     // ATTENTION
     "attention_type": "original", // 'original' or 'graves'
+=======
+    "attention_type": "graves", // 'original' or 'graves'
+>>>>>>> config update and bug fixes
     "attention_heads": 5, // number of attention heads (only for 'graves')
     "attention_norm": "sigmoid", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
     "windowing": false, // Enables attention windowing. Used only in eval mode.
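Note: the `<<<<<<< HEAD`, `=======`, and `>>>>>>> config update and bug fixes` lines added above are unresolved Git merge conflict markers committed into the config file. A minimal sketch of why that is a problem, assuming the config is eventually read as comment-stripped JSON (the repo's actual config loader is not part of this diff; the helper and the "config.json" path below are purely illustrative):

import json
import re

def load_commented_json(path):
    """Illustrative loader: drop '//' line comments, then parse the rest as JSON."""
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    text = re.sub(r"//[^\n]*", "", text)   # naive comment stripping, enough for this sketch
    return json.loads(text)                # raises on '<<<<<<< HEAD' and friends

try:
    config = load_commented_json("config.json")   # hypothetical path
except json.JSONDecodeError as err:
    print(f"config is unparseable until the conflict is resolved: {err}")

Keeping only one side of each conflicted block and deleting the marker lines restores a loadable config.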
@@ -134,11 +134,12 @@ class GravesAttention(nn.Module):
     def preprocess_inputs(self, inputs):
         return None
 
-    def forward(self, query, inputs, mask):
+    def forward(self, query, inputs, processed_inputs, mask):
         """
         shapes:
             query: B x D_attention_rnn
             inputs: B x T_in x D_encoder
+            processed_inputs: place_holder
             mask: B x T_in
         """
         gbk_t = self.N_a(query)
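The signature change above brings GravesAttention in line with an interface where the decoder computes any per-utterance input projection once, via preprocess_inputs, and passes the result back into every forward call. GMM-style Graves attention has nothing to precompute, so it returns None and ignores the extra argument. A minimal sketch of that contrast, assuming a Bahdanau-style counterpart; everything except the preprocess_inputs/forward names and the shapes is illustrative and not taken from the repo:

import torch
import torch.nn as nn

class ContentBasedAttentionSketch(nn.Module):
    """Illustrative content-based attention that really uses the precomputed projection."""
    def __init__(self, query_dim, embedding_dim, attn_dim):
        super().__init__()
        self.inputs_layer = nn.Linear(embedding_dim, attn_dim, bias=False)
        self.query_layer = nn.Linear(query_dim, attn_dim, bias=False)
        self.v = nn.Linear(attn_dim, 1, bias=True)

    def preprocess_inputs(self, inputs):
        return self.inputs_layer(inputs)            # B x T_in x D_attn, computed once

    def forward(self, query, inputs, processed_inputs, mask):
        # score each encoder step against the query, then take a weighted sum
        scores = self.v(torch.tanh(processed_inputs + self.query_layer(query).unsqueeze(1))).squeeze(-1)
        if mask is not None:
            scores = scores.masked_fill(~mask, float("-inf"))
        alpha = torch.softmax(scores, dim=-1)       # B x T_in
        return torch.bmm(alpha.unsqueeze(1), inputs).squeeze(1)

class GravesAttentionSketch(nn.Module):
    """Illustrative GMM attention: nothing to precompute, so the hook is a no-op."""
    def preprocess_inputs(self, inputs):
        return None                                 # placeholder, as in the diff above

    def forward(self, query, inputs, processed_inputs, mask):
        # processed_inputs is ignored; it exists only so both attention types
        # share one call signature on the decoder side
        alpha = torch.softmax(torch.zeros(inputs.shape[:2]), dim=-1)  # dummy uniform weights
        return torch.bmm(alpha.unsqueeze(1), inputs).squeeze(1)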
@@ -176,7 +177,6 @@ class GravesAttention(nn.Module):
         context = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1)
         self.attention_weights = alpha_t
         self.mu_prev = mu_t
-        breakpoint()
         return context
 
 
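The dropped breakpoint() call above is Python 3.7's built-in debug hook (PEP 553); left inside the attention forward pass it would pause every decoder step in pdb. Removing it is the right fix; for reference, the same hook can also be silenced without editing code via sys.breakpointhook or the standard PYTHONBREAKPOINT=0 environment variable, as in this small sketch:

import sys

# breakpoint() dispatches to sys.breakpointhook(), which defaults to pdb.set_trace().
# Overriding the hook (or running with PYTHONBREAKPOINT=0) turns stray calls into no-ops.
sys.breakpointhook = lambda *args, **kwargs: None

breakpoint()                 # would normally drop into pdb; now does nothing
print("training loop keeps running")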
@@ -180,7 +180,7 @@ class Decoder(nn.Module):
         self.context = torch.zeros(1, device=inputs.device).repeat(
             B, self.encoder_embedding_dim)
         self.inputs = inputs
-        self.processed_inputs = self.attention.inputs_layer(inputs)
+        self.processed_inputs = self.attention.preprocess_inputs(inputs)
         self.mask = mask
 
     def _reshape_memory(self, memory):
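From the decoder side, the last hunk stops reaching into self.attention.inputs_layer (which Graves attention does not define) and goes through the shared preprocess_inputs hook instead. A small runnable sketch of the resulting call pattern; apart from the preprocess_inputs/forward interface, the stub class and shapes below are made up for illustration:

import torch
import torch.nn as nn

class GravesLikeAttentionStub(nn.Module):
    """Stub with the preprocess_inputs/forward interface used in the diff."""
    def preprocess_inputs(self, inputs):
        return None                                   # nothing to precompute

    def forward(self, query, inputs, processed_inputs, mask):
        # Uniform averaging stands in for real GMM attention to keep the sketch short.
        alpha = torch.full(inputs.shape[:2], 1.0 / inputs.shape[1])
        return torch.bmm(alpha.unsqueeze(1), inputs).squeeze(1)

B, T_in, D_encoder, D_attention_rnn = 2, 11, 512, 1024
encoder_outputs = torch.randn(B, T_in, D_encoder)
attention = GravesLikeAttentionStub()

# computed once per utterance, reused at every decoder step (None for Graves attention)
processed_inputs = attention.preprocess_inputs(encoder_outputs)

query = torch.randn(B, D_attention_rnn)               # attention RNN state for one step
context = attention(query, encoder_outputs, processed_inputs, mask=None)
print(context.shape)                                   # torch.Size([2, 512])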