mirror of https://github.com/coqui-ai/TTS.git
init with embedding lyaers
This commit is contained in:
parent
d28bbe09fb
commit
c5b6227848
|
@ -323,6 +323,10 @@ class Decoder(nn.Module):
|
|||
[nn.GRUCell(256, 256) for _ in range(2)])
|
||||
# RNN_state -> |Linear| -> mel_spec
|
||||
self.proj_to_mel = nn.Linear(256, memory_dim * r)
|
||||
# learn init values instead of zero init.
|
||||
self.attention_rnn_init = nn.Embedding(1, 256)
|
||||
self.memory_init = nn.Embedding(1, r * memory_dim)
|
||||
self.decoder_rnn_inits = nn.Embedding(2, 256)
|
||||
self.stopnet = StopNet(256 + memory_dim * r)
|
||||
# self.init_layers()
|
||||
|
||||
|
@ -354,12 +358,12 @@ class Decoder(nn.Module):
|
|||
B = inputs.size(0)
|
||||
T = inputs.size(1)
|
||||
# go frame as zeros matrix
|
||||
initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
|
||||
initial_memory = self.memory_init(inputs.data.new_zeros(B).long())
|
||||
# decoder states
|
||||
attention_rnn_hidden = inputs.data.new(B, 256).zero_()
|
||||
attention_rnn_hidden = self.attention_rnn_init(inputs.data.new_zeros(B).long())
|
||||
decoder_rnn_hiddens = [
|
||||
inputs.data.new(B, 256).zero_()
|
||||
for _ in range(len(self.decoder_rnns))
|
||||
self.decoder_rnn_inits(inputs.data.new_tensor([idx]*B).long())
|
||||
for idx in range(len(self.decoder_rnns))
|
||||
]
|
||||
current_context_vec = inputs.data.new(B, self.in_features).zero_()
|
||||
# attention states
|
||||
|
|
Loading…
Reference in New Issue