mirror of https://github.com/coqui-ai/TTS.git
Add a constant attnetion model type to attention class
This commit is contained in:
parent
819011e1a2
commit
14f9d06b31
|
@ -219,7 +219,7 @@ class Decoder(nn.Module):
|
||||||
# memory -> |Prenet| -> processed_memory
|
# memory -> |Prenet| -> processed_memory
|
||||||
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
|
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
|
||||||
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
|
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
|
||||||
self.attention_rnn = AttentionRNN(256, in_features, 128)
|
self.attention_rnn = AttentionRNN(256, in_features, 128, align_model='ls')
|
||||||
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
|
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
|
||||||
self.project_to_decoder_in = nn.Linear(256+in_features, 256)
|
self.project_to_decoder_in = nn.Linear(256+in_features, 256)
|
||||||
# decoder_RNN_input -> |RNN| -> RNN_state
|
# decoder_RNN_input -> |RNN| -> RNN_state
|
||||||
|
@ -257,14 +257,15 @@ class Decoder(nn.Module):
|
||||||
" !! Dimension mismatch {} vs {} * {}".format(memory.size(-1),
|
" !! Dimension mismatch {} vs {} * {}".format(memory.size(-1),
|
||||||
self.memory_dim, self.r)
|
self.memory_dim, self.r)
|
||||||
T_decoder = memory.size(1)
|
T_decoder = memory.size(1)
|
||||||
# go frame - 0 frames tarting the sequence
|
# go frame as zeros matrix
|
||||||
initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
|
initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
|
||||||
# Init decoder states
|
# decoder states
|
||||||
attention_rnn_hidden = inputs.data.new(B, 256).zero_()
|
attention_rnn_hidden = inputs.data.new(B, 256).zero_()
|
||||||
decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
|
decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
|
||||||
for _ in range(len(self.decoder_rnns))]
|
for _ in range(len(self.decoder_rnns))]
|
||||||
current_context_vec = inputs.data.new(B, 256).zero_()
|
current_context_vec = inputs.data.new(B, 256).zero_()
|
||||||
stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
|
stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
|
||||||
|
# attention states
|
||||||
attention = inputs.data.new(B, T).zero_()
|
attention = inputs.data.new(B, T).zero_()
|
||||||
attention_cum = inputs.data.new(B, T).zero_()
|
attention_cum = inputs.data.new(B, T).zero_()
|
||||||
# Time first (T_decoder, B, memory_dim)
|
# Time first (T_decoder, B, memory_dim)
|
||||||
|
|
Loading…
Reference in New Issue