mirror of https://github.com/coqui-ai/TTS.git
renamed attention_rnn to query_rnn
parent f08b31fd8d
commit 98edb7a4f8
@@ -108,19 +108,19 @@ class LocationLayer(nn.Module):
 class Attention(nn.Module):
     # Pylint gets confused by PyTorch conventions here
     #pylint: disable=attribute-defined-outside-init
-    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
+    def __init__(self, query_dim, embedding_dim, attention_dim,
                  location_attention, attention_location_n_filters,
                  attention_location_kernel_size, windowing, norm, forward_attn,
                  trans_agent, forward_attn_mask):
         super(Attention, self).__init__()
         self.query_layer = Linear(
-            attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
+            query_dim, attention_dim, bias=False, init_gain='tanh')
         self.inputs_layer = Linear(
             embedding_dim, attention_dim, bias=False, init_gain='tanh')
         self.v = Linear(attention_dim, 1, bias=True)
         if trans_agent:
             self.ta = nn.Linear(
-                attention_rnn_dim + embedding_dim, 1, bias=True)
+                query_dim + embedding_dim, 1, bias=True)
         if location_attention:
             self.location_layer = LocationLayer(
                 attention_dim,
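A minimal stand-in sketch of the renamed constructor and the shapes it implies is below. It substitutes plain `torch.nn.Linear` for the repo's custom `Linear` wrapper, omits the location/windowing/forward branches, and uses an arbitrary `embedding_dim=512` for the encoder size; only `query_dim`, `attention_dim=128`, and the `query_dim + embedding_dim` transition-agent input come from the diff.

```python
import torch
from torch import nn


class AttentionSketch(nn.Module):
    """Illustrative stand-in for the renamed constructor (not the repo class)."""

    def __init__(self, query_dim, embedding_dim, attention_dim, trans_agent=False):
        super().__init__()
        # query (decoder query-RNN state) -> attention space
        self.query_layer = nn.Linear(query_dim, attention_dim, bias=False)
        # encoder outputs -> attention space
        self.inputs_layer = nn.Linear(embedding_dim, attention_dim, bias=False)
        # attention space -> one score per encoder step
        self.v = nn.Linear(attention_dim, 1, bias=True)
        if trans_agent:
            # transition agent sees [query | context], hence query_dim + embedding_dim
            self.ta = nn.Linear(query_dim + embedding_dim, 1, bias=True)

    def forward(self, query, inputs):
        # additive (Bahdanau-style) scores, then softmax over encoder steps
        scores = self.v(torch.tanh(
            self.query_layer(query).unsqueeze(1) + self.inputs_layer(inputs)))
        alignment = torch.softmax(scores.squeeze(-1), dim=-1)
        context = torch.bmm(alignment.unsqueeze(1), inputs).squeeze(1)
        return context, alignment


# query_dim=256 matches the Tacotron decoder below, 1024 the Tacotron2 one
attn = AttentionSketch(query_dim=256, embedding_dim=512, attention_dim=128)
context, alignment = attn(torch.zeros(2, 256), torch.zeros(2, 37, 512))
```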
@@ -203,11 +203,12 @@ class Attention(nn.Module):

     def apply_forward_attention(self, inputs, alignment, query):
         # forward attention
-        prev_alpha = F.pad(self.alpha[:, :-1].clone(),
-                           (1, 0, 0, 0)).to(inputs.device)
+        prev_alpha = F.pad(self.alpha[:, :-1].clone().to(inputs.device),
+                           (1, 0, 0, 0))
         # compute transition potentials
-        alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
-                  self.u * prev_alpha) + 1e-8) * alignment
+        alpha = ((1 - self.u) * self.alpha
+                 + self.u * prev_alpha
+                 + 1e-8) * alignment
         # force incremental alignment
         if not self.training and self.forward_attn_mask:
             _, n = prev_alpha.max(1)
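The expression rewritten in this hunk is the forward-attention update: the previous weights are shifted one encoder step to the right and mixed with the unshifted weights through the transition probability `u`. A small self-contained sketch of that step follows; the final renormalisation is an assumption about the surrounding code and is not shown in the hunk itself.

```python
import torch
import torch.nn.functional as F


def forward_attention_step(alpha, alignment, u):
    """One forward-attention update, mirroring the expression in the hunk above.

    alpha:     (B, T) attention weights from the previous decoder step
    alignment: (B, T) raw attention for the current step
    u:         (B, 1) transition probability from the transition agent
    """
    # shift alpha one encoder step to the right: prev_alpha[:, j] = alpha[:, j-1]
    prev_alpha = F.pad(alpha[:, :-1], (1, 0, 0, 0))
    # stay at the same step with probability (1 - u), move forward with probability u
    new_alpha = ((1 - u) * alpha + u * prev_alpha + 1e-8) * alignment
    # assumed renormalisation so the weights sum to one over encoder steps
    return new_alpha / new_alpha.sum(dim=1, keepdim=True)


alpha = torch.tensor([[0.7, 0.2, 0.1, 0.0]])
alignment = torch.tensor([[0.1, 0.6, 0.2, 0.1]])
print(forward_attention_step(alpha, alignment, u=torch.tensor([[0.5]])))
```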
@@ -231,19 +232,20 @@ class Attention(nn.Module):
             self.u = torch.sigmoid(self.ta(ta_input))
         return context, self.alpha

-    def forward(self, attention_hidden_state, inputs, processed_inputs, mask):
+    def forward(self, query, inputs, processed_inputs, mask):
         if self.location_attention:
             attention, processed_query = self.get_location_attention(
-                attention_hidden_state, processed_inputs)
+                query, processed_inputs)
         else:
             attention, processed_query = self.get_attention(
-                attention_hidden_state, processed_inputs)
+                query, processed_inputs)
         # apply masking
         if mask is not None:
             attention.data.masked_fill_(1 - mask, self._mask_value)
         # apply windowing - only in eval mode
         if not self.training and self.windowing:
             attention = self.apply_windowing(attention, inputs)

         # normalize attention values
         if self.norm == "softmax":
             alignment = torch.softmax(attention, dim=-1)
@@ -258,7 +260,7 @@ class Attention(nn.Module):
         # apply forward attention if enabled
         if self.forward_attn:
             context, self.attention_weights = self.apply_forward_attention(
-                inputs, alignment, attention_hidden_state)
+                inputs, alignment, query)
         else:
             context = torch.bmm(alignment.unsqueeze(1), inputs)
             context = context.squeeze(1)
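Taken together, these hunks show the score → mask → softmax → context path of `Attention.forward`. A minimal sketch of that path is below; it covers only the softmax branch, skips windowing, location, and forward attention, and writes the diff's `1 - mask` byte-mask idiom with a boolean mask instead.

```python
import torch


def attention_context(scores, inputs, mask=None, mask_value=-float("inf")):
    """Scores -> (masked) softmax alignment -> context, as in Attention.forward."""
    if mask is not None:
        # zero entries in the mask mark padded encoder steps
        scores = scores.masked_fill(~mask, mask_value)
    alignment = torch.softmax(scores, dim=-1)             # (B, T)
    context = torch.bmm(alignment.unsqueeze(1), inputs)   # (B, 1, D)
    return context.squeeze(1), alignment                  # (B, D), (B, T)


inputs = torch.randn(2, 5, 8)        # (B, T, D) encoder outputs
scores = torch.randn(2, 5)           # unnormalised attention scores
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]], dtype=torch.bool)
context, alignment = attention_context(scores, inputs, mask)
```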
@@ -283,6 +283,7 @@ class Decoder(nn.Module):
         self.memory_size = memory_size if memory_size > 0 else r
         self.memory_dim = memory_dim
         self.separate_stopnet = separate_stopnet
+        self.query_dim = 256
         # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(
             memory_dim * self.memory_size,
@@ -290,18 +291,18 @@ class Decoder(nn.Module):
             prenet_dropout,
             out_features=[256, 128])
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
-        self.attention_rnn = nn.GRUCell(in_features + 128, 256)
-        self.attention_layer = Attention(attention_rnn_dim=256,
+        self.query_rnn = nn.GRUCell(in_features + 128, self.query_dim)
+        self.attention = Attention(query_dim=self.query_dim,
                                    embedding_dim=in_features,
                                    attention_dim=128,
                                    location_attention=location_attn,
                                    attention_location_n_filters=32,
                                    attention_location_kernel_size=31,
                                    windowing=attn_windowing,
                                    norm=attn_norm,
                                    forward_attn=forward_attn,
                                    trans_agent=trans_agent,
                                    forward_attn_mask=forward_attn_mask)
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
         # decoder_RNN_input -> |RNN| -> RNN_state
@@ -310,7 +311,7 @@ class Decoder(nn.Module):
         # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
         # learn init values instead of zero init.
-        self.attention_rnn_init = nn.Embedding(1, 256)
+        self.query_rnn_init = nn.Embedding(1, 256)
         self.memory_init = nn.Embedding(1, self.memory_size * memory_dim)
         self.decoder_rnn_inits = nn.Embedding(2, 256)
         self.stopnet = StopNet(256 + memory_dim * r)
@@ -347,18 +348,18 @@ class Decoder(nn.Module):
         self.memory_input = self.memory_init(inputs.data.new_zeros(B).long())

         # decoder states
-        self.attention_rnn_hidden = self.attention_rnn_init(
+        self.query = self.query_rnn_init(
             inputs.data.new_zeros(B).long())
         self.decoder_rnn_hiddens = [
             self.decoder_rnn_inits(inputs.data.new_tensor([idx] * B).long())
             for idx in range(len(self.decoder_rnns))
         ]
-        self.current_context_vec = inputs.data.new(B, self.in_features).zero_()
+        self.context_vec = inputs.data.new(B, self.in_features).zero_()
         # attention states
         self.attention = inputs.data.new(B, T).zero_()
         self.attention_cum = inputs.data.new(B, T).zero_()
         # cache attention inputs
-        self.processed_inputs = self.attention_layer.inputs_layer(inputs)
+        self.processed_inputs = self.attention.inputs_layer(inputs)

     def _parse_outputs(self, outputs, attentions, stop_tokens):
         # Back to batch first
@@ -370,13 +371,15 @@ class Decoder(nn.Module):
     def decode(self, inputs, mask=None):
         # Prenet
         processed_memory = self.prenet(self.memory_input)

         # Attention RNN
-        self.attention_rnn_hidden = self.attention_rnn(torch.cat((processed_memory, self.current_context_vec), -1), self.attention_rnn_hidden)
-        self.current_context_vec = self.attention_layer(self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
-        # Concat RNN output and attention context vector
+        self.query = self.query_rnn(torch.cat((processed_memory, self.context_vec), -1), self.query)
+        self.context_vec = self.attention(self.query, inputs, self.processed_inputs, mask)
+        # Concat query and attention context vector
         decoder_input = self.project_to_decoder_in(
-            torch.cat((self.attention_rnn_hidden, self.current_context_vec),
-                      -1))
+            torch.cat((self.query, self.context_vec), -1))
         # Pass through the decoder RNNs
         for idx in range(len(self.decoder_rnns)):
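Under the renamed variables, one GRU-based decode step wires together as in the sketch below. The layer sizes (prenet output 128, decoder width 256, `query_dim` 256) come from the hunks above; `in_features`, the encoder output size, is assumed to be 256 here purely for illustration, and the attention call is left as a comment since the repo's `Attention` module is not redefined in this sketch.

```python
import torch
from torch import nn

# Stand-in modules wired with the sizes shown in the hunks above.
in_features, query_dim, B = 256, 256, 2
query_rnn = nn.GRUCell(in_features + 128, query_dim)
project_to_decoder_in = nn.Linear(query_dim + in_features, 256)

# one decode step under the renamed state variables
processed_memory = torch.zeros(B, 128)      # prenet output
context_vec = torch.zeros(B, in_features)   # attention context from the last step
query = torch.zeros(B, query_dim)           # query-RNN hidden state

query = query_rnn(torch.cat((processed_memory, context_vec), -1), query)
# context_vec = attention(query, inputs, processed_inputs, mask)  # repo call
decoder_input = project_to_decoder_in(torch.cat((query, context_vec), -1))
```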
@@ -384,18 +387,18 @@ class Decoder(nn.Module):
             # Residual connection
             decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input
         decoder_output = decoder_input
-        del decoder_input
         # predict mel vectors from decoder vectors
         output = self.proj_to_mel(decoder_output)
         output = torch.sigmoid(output)

         # predict stop token
         stopnet_input = torch.cat([decoder_output, output], -1)
-        del decoder_output
         if self.separate_stopnet:
             stop_token = self.stopnet(stopnet_input.detach())
         else:
             stop_token = self.stopnet(stopnet_input)
-        return output, stop_token, self.attention_layer.attention_weights
+        return output, stop_token, self.attention.attention_weights

     def _update_memory_queue(self, new_memory):
         if self.memory_size > 0 and new_memory.shape[-1] < self.memory_size:
@@ -427,7 +430,7 @@ class Decoder(nn.Module):
         stop_tokens = []
         t = 0
         self._init_states(inputs)
-        self.attention_layer.init_states(inputs)
+        self.attention.init_states(inputs)
         while len(outputs) < memory.size(0):
             if t > 0:
                 new_memory = memory[t - 1]
@@ -453,8 +456,8 @@ class Decoder(nn.Module):
         stop_tokens = []
         t = 0
         self._init_states(inputs)
-        self.attention_layer.init_win_idx()
-        self.attention_layer.init_states(inputs)
+        self.attention.init_win_idx()
+        self.attention.init_states(inputs)
         while True:
             if t > 0:
                 new_memory = outputs[-1]
@@ -104,7 +104,7 @@ class Decoder(nn.Module):
         self.r = r
         self.encoder_embedding_dim = in_features
         self.separate_stopnet = separate_stopnet
-        self.attention_rnn_dim = 1024
+        self.query_dim = 1024
         self.decoder_rnn_dim = 1024
         self.prenet_dim = 256
         self.max_decoder_steps = 1000
@@ -116,22 +116,22 @@ class Decoder(nn.Module):
             prenet_dropout,
             [self.prenet_dim, self.prenet_dim], bias=False)

-        self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
-                                         self.attention_rnn_dim)
+        self.query_rnn = nn.LSTMCell(self.prenet_dim + in_features,
+                                     self.query_dim)

-        self.attention_layer = Attention(attention_rnn_dim=self.attention_rnn_dim,
+        self.attention = Attention(query_dim=self.query_dim,
                                    embedding_dim=in_features,
                                    attention_dim=128,
                                    location_attention=location_attn,
                                    attention_location_n_filters=32,
                                    attention_location_kernel_size=31,
                                    windowing=attn_win,
                                    norm=attn_norm,
                                    forward_attn=forward_attn,
                                    trans_agent=trans_agent,
                                    forward_attn_mask=forward_attn_mask)

-        self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
+        self.decoder_rnn = nn.LSTMCell(self.query_dim + in_features,
                                        self.decoder_rnn_dim, 1)

         self.linear_projection = Linear(self.decoder_rnn_dim + in_features,
@@ -145,7 +145,7 @@ class Decoder(nn.Module):
                    bias=True,
                    init_gain='sigmoid'))

-        self.attention_rnn_init = nn.Embedding(1, self.attention_rnn_dim)
+        self.query_rnn_init = nn.Embedding(1, self.query_dim)
         self.go_frame_init = nn.Embedding(1, self.mel_channels * r)
         self.decoder_rnn_inits = nn.Embedding(1, self.decoder_rnn_dim)
         self.memory_truncated = None
@@ -160,10 +160,10 @@ class Decoder(nn.Module):
         # T = inputs.size(1)

         if not keep_states:
-            self.attention_hidden = self.attention_rnn_init(
+            self.query = self.query_rnn_init(
                 inputs.data.new_zeros(B).long())
-            self.attention_cell = Variable(
-                inputs.data.new(B, self.attention_rnn_dim).zero_())
+            self.query_rnn_cell_state = Variable(
+                inputs.data.new(B, self.query_dim).zero_())

             self.decoder_hidden = self.decoder_rnn_inits(
                 inputs.data.new_zeros(B).long())
|
@ -174,7 +174,7 @@ class Decoder(nn.Module):
|
||||||
inputs.data.new(B, self.encoder_embedding_dim).zero_())
|
inputs.data.new(B, self.encoder_embedding_dim).zero_())
|
||||||
|
|
||||||
self.inputs = inputs
|
self.inputs = inputs
|
||||||
self.processed_inputs = self.attention_layer.inputs_layer(inputs)
|
self.processed_inputs = self.attention.inputs_layer(inputs)
|
||||||
self.mask = mask
|
self.mask = mask
|
||||||
|
|
||||||
def _reshape_memory(self, memories):
|
def _reshape_memory(self, memories):
|
||||||
|
@ -193,18 +193,18 @@ class Decoder(nn.Module):
|
||||||
return outputs, stop_tokens, alignments
|
return outputs, stop_tokens, alignments
|
||||||
|
|
||||||
def decode(self, memory):
|
def decode(self, memory):
|
||||||
cell_input = torch.cat((memory, self.context), -1)
|
query_input = torch.cat((memory, self.context), -1)
|
||||||
self.attention_hidden, self.attention_cell = self.attention_rnn(
|
self.query, self.query_rnn_cell_state = self.query_rnn(
|
||||||
cell_input, (self.attention_hidden, self.attention_cell))
|
query_input, (self.query, self.query_rnn_cell_state))
|
||||||
self.attention_hidden = F.dropout(
|
self.query = F.dropout(
|
||||||
self.attention_hidden, self.p_attention_dropout, self.training)
|
self.query, self.p_attention_dropout, self.training)
|
||||||
self.attention_cell = F.dropout(
|
self.query_rnn_cell_state = F.dropout(
|
||||||
self.attention_cell, self.p_attention_dropout, self.training)
|
self.query_rnn_cell_state, self.p_attention_dropout, self.training)
|
||||||
|
|
||||||
self.context = self.attention_layer(self.attention_hidden, self.inputs,
|
self.context = self.attention(self.query, self.inputs,
|
||||||
self.processed_inputs, self.mask)
|
self.processed_inputs, self.mask)
|
||||||
|
|
||||||
memory = torch.cat((self.attention_hidden, self.context), -1)
|
memory = torch.cat((self.query, self.context), -1)
|
||||||
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
|
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
|
||||||
memory, (self.decoder_hidden, self.decoder_cell))
|
memory, (self.decoder_hidden, self.decoder_cell))
|
||||||
self.decoder_hidden = F.dropout(self.decoder_hidden,
|
self.decoder_hidden = F.dropout(self.decoder_hidden,
|
||||||
|
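In the LSTM-based decoder, the renamed `query_rnn` returns both the query and its cell state, and dropout is applied to each before the attention call. The sketch below traces that step with stand-in tensors; `query_dim=1024` and `prenet_dim=256` come from the hunks above, while `in_features=512` and the 0.1 dropout rate are assumed values for illustration, and the attention call is again left as a comment.

```python
import torch
from torch import nn
from torch.nn import functional as F

# Sizes from the hunks above; `in_features` and the dropout rate are assumptions.
in_features, prenet_dim, query_dim, B = 512, 256, 1024, 2
query_rnn = nn.LSTMCell(prenet_dim + in_features, query_dim)

memory = torch.zeros(B, prenet_dim)                 # prenet output for this step
context = torch.zeros(B, in_features)               # attention context from last step
query = torch.zeros(B, query_dim)                   # LSTM hidden state
query_rnn_cell_state = torch.zeros(B, query_dim)    # LSTM cell state
p_attention_dropout = 0.1

query_input = torch.cat((memory, context), -1)
query, query_rnn_cell_state = query_rnn(
    query_input, (query, query_rnn_cell_state))
query = F.dropout(query, p_attention_dropout, training=True)
query_rnn_cell_state = F.dropout(
    query_rnn_cell_state, p_attention_dropout, training=True)
# context = attention(query, inputs, processed_inputs, mask)  # repo call
decoder_rnn_input = torch.cat((query, context), -1)  # feeds LSTMCell(query_dim + in_features, ...)
```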
@@ -223,7 +223,7 @@ class Decoder(nn.Module):
             stop_token = self.stopnet(stopnet_input.detach())
         else:
             stop_token = self.stopnet(stopnet_input)
-        return decoder_output, stop_token, self.attention_layer.attention_weights
+        return decoder_output, stop_token, self.attention.attention_weights

     def forward(self, inputs, memories, mask):
         memory = self.get_go_frame(inputs).unsqueeze(0)
@@ -232,7 +232,7 @@ class Decoder(nn.Module):
         memories = self.prenet(memories)

         self._init_states(inputs, mask=mask)
-        self.attention_layer.init_states(inputs)
+        self.attention.init_states(inputs)

         outputs, stop_tokens, alignments = [], [], []
         while len(outputs) < memories.size(0) - 1:
@@ -251,8 +251,8 @@ class Decoder(nn.Module):
         memory = self.get_go_frame(inputs)
         self._init_states(inputs, mask=None)

-        self.attention_layer.init_win_idx()
-        self.attention_layer.init_states(inputs)
+        self.attention.init_win_idx()
+        self.attention.init_states(inputs)

         outputs, stop_tokens, alignments, t = [], [], [], 0
         stop_flags = [True, False, False]
@@ -295,8 +295,8 @@ class Decoder(nn.Module):
         else:
             self._init_states(inputs, mask=None, keep_states=True)

-        self.attention_layer.init_win_idx()
-        self.attention_layer.init_states(inputs)
+        self.attention.init_win_idx()
+        self.attention.init_states(inputs)
         outputs, stop_tokens, alignments, t = [], [], [], 0
         stop_flags = [True, False, False]
         stop_count = 0