mirror of https://github.com/coqui-ai/TTS.git
Merge pull request #245 from twerkmeister/attention_naming
renamed attention_rnn to query_rnn
Commit 69abea55d6
@@ -108,19 +108,19 @@ class LocationLayer(nn.Module):
 class Attention(nn.Module):
     # Pylint gets confused by PyTorch conventions here
     #pylint: disable=attribute-defined-outside-init
-    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
+    def __init__(self, query_dim, embedding_dim, attention_dim,
                  location_attention, attention_location_n_filters,
                  attention_location_kernel_size, windowing, norm, forward_attn,
                  trans_agent, forward_attn_mask):
         super(Attention, self).__init__()
         self.query_layer = Linear(
-            attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
+            query_dim, attention_dim, bias=False, init_gain='tanh')
         self.inputs_layer = Linear(
             embedding_dim, attention_dim, bias=False, init_gain='tanh')
         self.v = Linear(attention_dim, 1, bias=True)
         if trans_agent:
             self.ta = nn.Linear(
-                attention_rnn_dim + embedding_dim, 1, bias=True)
+                query_dim + embedding_dim, 1, bias=True)
         if location_attention:
             self.location_layer = LocationLayer(
                 attention_dim,
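For orientation, the three projections set up here (query_layer, inputs_layer, v) compute a Bahdanau-style additive score between the decoder query and the encoder outputs. A minimal standalone sketch, using plain nn.Linear in place of the repo's Linear wrapper and made-up sizes:

import torch
import torch.nn as nn

# Hypothetical sizes for illustration only.
query_dim, embedding_dim, attention_dim = 256, 512, 128
B, T = 2, 37  # batch size, encoder time steps

query_layer = nn.Linear(query_dim, attention_dim, bias=False)
inputs_layer = nn.Linear(embedding_dim, attention_dim, bias=False)
v = nn.Linear(attention_dim, 1, bias=True)

query = torch.randn(B, query_dim)          # decoder query (was the "attention_rnn" hidden state)
inputs = torch.randn(B, T, embedding_dim)  # encoder outputs

# Additive (Bahdanau-style) energies: one scalar per encoder step.
energies = v(torch.tanh(query_layer(query).unsqueeze(1) + inputs_layer(inputs))).squeeze(-1)
print(energies.shape)  # torch.Size([2, 37])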
@@ -201,16 +201,17 @@ class Attention(nn.Module):
             self.win_idx = torch.argmax(attention, 1).long()[0].item()
         return attention

-    def apply_forward_attention(self, inputs, alignment, query):
+    def apply_forward_attention(self, alignment):
         # forward attention
-        prev_alpha = F.pad(self.alpha[:, :-1].clone(),
-                           (1, 0, 0, 0)).to(inputs.device)
+        fwd_shifted_alpha = F.pad(self.alpha[:, :-1].clone().to(alignment.device),
+                                  (1, 0, 0, 0))
         # compute transition potentials
-        alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
-                  self.u * prev_alpha) + 1e-8) * alignment
+        alpha = ((1 - self.u) * self.alpha
+                 + self.u * fwd_shifted_alpha
+                 + 1e-8) * alignment
         # force incremental alignment
         if not self.training and self.forward_attn_mask:
-            _, n = prev_alpha.max(1)
+            _, n = fwd_shifted_alpha.max(1)
             val, n2 = alpha.max(1)
             for b in range(alignment.shape[0]):
                 alpha[b, n[b] + 3:] = 0
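The rewritten apply_forward_attention only updates the alignment; the context vector and the transition agent move to forward() later in this diff. A rough standalone sketch of the forward-attention recursion under assumed shapes (B batch items, T encoder steps, u the transition probability):

import torch
import torch.nn.functional as F

B, T = 2, 6
alignment = torch.softmax(torch.randn(B, T), dim=-1)   # current attention weights
alpha_prev = torch.softmax(torch.randn(B, T), dim=-1)  # self.alpha from the previous step
u = torch.sigmoid(torch.randn(B, 1))                   # self.u, transition agent output

# Shift the previous alpha one step forward along the time axis.
fwd_shifted_alpha = F.pad(alpha_prev[:, :-1], (1, 0, 0, 0))

# Transition potentials: stay with prob (1 - u) or move forward with prob u, gated by the alignment.
alpha = ((1 - u) * alpha_prev + u * fwd_shifted_alpha + 1e-8) * alignment
alpha = alpha / alpha.sum(dim=1, keepdim=True)  # renormalize, as in the diff
print(alpha.sum(dim=1))  # ~1.0 per batch item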
@@ -220,30 +221,24 @@ class Attention(nn.Module):
                 alpha[b,
                       (n[b] - 2
                        )] = 0.01 * val[b]  # smoothing factor for the prev step
-        # compute attention weights
-        self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
-        # compute context
-        context = torch.bmm(self.alpha.unsqueeze(1), inputs)
-        context = context.squeeze(1)
-        # compute transition agent
-        if self.trans_agent:
-            ta_input = torch.cat([context, query.squeeze(1)], dim=-1)
-            self.u = torch.sigmoid(self.ta(ta_input))
-        return context, self.alpha
+        # renormalize attention weights
+        alpha = alpha / alpha.sum(dim=1, keepdim=True)
+        return alpha

-    def forward(self, attention_hidden_state, inputs, processed_inputs, mask):
+    def forward(self, query, inputs, processed_inputs, mask):
         if self.location_attention:
-            attention, processed_query = self.get_location_attention(
-                attention_hidden_state, processed_inputs)
+            attention, _ = self.get_location_attention(
+                query, processed_inputs)
         else:
-            attention, processed_query = self.get_attention(
-                attention_hidden_state, processed_inputs)
+            attention, _ = self.get_attention(
+                query, processed_inputs)
         # apply masking
         if mask is not None:
             attention.data.masked_fill_(1 - mask, self._mask_value)
         # apply windowing - only in eval mode
         if not self.training and self.windowing:
             attention = self.apply_windowing(attention, inputs)

         # normalize attention values
         if self.norm == "softmax":
             alignment = torch.softmax(attention, dim=-1)
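The switch from .sum(dim=1).unsqueeze(1) to .sum(dim=1, keepdim=True) is behavior-preserving for the 2-D alpha used here; a quick check under that assumed (batch, time) shape:

import torch

alpha = torch.rand(4, 9) + 1e-8  # (batch, time), strictly positive
a = alpha / alpha.sum(dim=1).unsqueeze(1)
b = alpha / alpha.sum(dim=1, keepdim=True)
assert torch.allclose(a, b)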
@@ -252,15 +247,22 @@ class Attention(nn.Module):
                 attention).sum(
                     dim=1, keepdim=True)
         else:
-            raise RuntimeError("Unknown value for attention norm type")
+            raise ValueError("Unknown value for attention norm type")

         if self.location_attention:
             self.update_location_attention(alignment)

         # apply forward attention if enabled
         if self.forward_attn:
-            context, self.attention_weights = self.apply_forward_attention(
-                inputs, alignment, attention_hidden_state)
-        else:
-            context = torch.bmm(alignment.unsqueeze(1), inputs)
-            context = context.squeeze(1)
-            self.attention_weights = alignment
+            alignment = self.apply_forward_attention(alignment)
+            self.alpha = alignment
+
+        context = torch.bmm(alignment.unsqueeze(1), inputs)
+        context = context.squeeze(1)
+        self.attention_weights = alignment
+
+        # compute transition agent
+        if self.forward_attn and self.trans_agent:
+            ta_input = torch.cat([context, query.squeeze(1)], dim=-1)
+            self.u = torch.sigmoid(self.ta(ta_input))
         return context
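With this hunk, the context vector and the transition-agent update both happen in forward(), once the (possibly forward-attention-filtered) alignment is known. A condensed sketch of that tail end, with hand-rolled tensors standing in for the module's state and assumed sizes:

import torch
import torch.nn as nn

B, T, embedding_dim, query_dim = 2, 6, 512, 256
inputs = torch.randn(B, T, embedding_dim)             # encoder outputs
query = torch.randn(B, query_dim)                     # decoder query (already 2-D here)
alignment = torch.softmax(torch.randn(B, T), dim=-1)
ta = nn.Linear(query_dim + embedding_dim, 1)          # stands in for self.ta from __init__

# context = alignment-weighted sum of encoder outputs
context = torch.bmm(alignment.unsqueeze(1), inputs).squeeze(1)  # (B, embedding_dim)

# transition agent: probability of moving one encoder step forward next time
ta_input = torch.cat([context, query], dim=-1)  # the module calls query.squeeze(1); query is 2-D in this sketch
u = torch.sigmoid(ta(ta_input))                 # (B, 1), stored as self.u for the next forward-attention update
print(context.shape, u.shape)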
@@ -283,6 +283,7 @@ class Decoder(nn.Module):
         self.memory_size = memory_size if memory_size > 0 else r
         self.memory_dim = memory_dim
         self.separate_stopnet = separate_stopnet
+        self.query_dim = 256
         # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(
             memory_dim * self.memory_size,
@@ -290,18 +291,20 @@ class Decoder(nn.Module):
             prenet_dropout,
             out_features=[256, 128])
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
-        self.attention_rnn = nn.GRUCell(in_features + 128, 256)
-        self.attention_layer = Attention(attention_rnn_dim=256,
-                                         embedding_dim=in_features,
-                                         attention_dim=128,
-                                         location_attention=location_attn,
-                                         attention_location_n_filters=32,
-                                         attention_location_kernel_size=31,
-                                         windowing=attn_windowing,
-                                         norm=attn_norm,
-                                         forward_attn=forward_attn,
-                                         trans_agent=trans_agent,
-                                         forward_attn_mask=forward_attn_mask)
+        # attention_rnn generates queries for the attention mechanism
+        self.attention_rnn = nn.GRUCell(in_features + 128, self.query_dim)
+        self.attention = Attention(query_dim=self.query_dim,
+                                   embedding_dim=in_features,
+                                   attention_dim=128,
+                                   location_attention=location_attn,
+                                   attention_location_n_filters=32,
+                                   attention_location_kernel_size=31,
+                                   windowing=attn_windowing,
+                                   norm=attn_norm,
+                                   forward_attn=forward_attn,
+                                   trans_agent=trans_agent,
+                                   forward_attn_mask=forward_attn_mask)
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
         # decoder_RNN_input -> |RNN| -> RNN_state
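One practical side effect of renaming the submodule attribute (attention_layer to attention) is that the parameter names in the model's state_dict change with it, so checkpoints saved before this commit will not load with strict checking unless the keys are remapped. A hedged sketch of the kind of remapping that would be needed; the prefix and the "model" key in the checkpoint dict are assumptions, not this repo's documented format:

import torch

def rename_attention_keys(state_dict):
    # Map pre-rename checkpoint keys onto the new attribute name.
    return {k.replace("attention_layer.", "attention."): v for k, v in state_dict.items()}

# hypothetical usage:
# ckpt = torch.load("old_checkpoint.pth", map_location="cpu")
# model.load_state_dict(rename_attention_keys(ckpt["model"]))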
@@ -347,18 +350,15 @@ class Decoder(nn.Module):
         self.memory_input = self.memory_init(inputs.data.new_zeros(B).long())

         # decoder states
-        self.attention_rnn_hidden = self.attention_rnn_init(
+        self.query = self.attention_rnn_init(
             inputs.data.new_zeros(B).long())
         self.decoder_rnn_hiddens = [
             self.decoder_rnn_inits(inputs.data.new_tensor([idx] * B).long())
             for idx in range(len(self.decoder_rnns))
         ]
-        self.current_context_vec = inputs.data.new(B, self.in_features).zero_()
-        # attention states
-        self.attention = inputs.data.new(B, T).zero_()
-        self.attention_cum = inputs.data.new(B, T).zero_()
+        self.context_vec = inputs.data.new(B, self.in_features).zero_()
         # cache attention inputs
-        self.processed_inputs = self.attention_layer.inputs_layer(inputs)
+        self.processed_inputs = self.attention.inputs_layer(inputs)

     def _parse_outputs(self, outputs, attentions, stop_tokens):
         # Back to batch first
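The initial query and the per-layer decoder hidden states come from nn.Embedding tables indexed with constant ids, a compact way to make the initial RNN states learnable instead of fixed zeros. A minimal sketch of the trick, with assumed sizes:

import torch
import torch.nn as nn

B, query_dim = 4, 256  # assumed batch size and query size

# One learnable row, broadcast over the batch via an all-zeros index.
attention_rnn_init = nn.Embedding(1, query_dim)
query0 = attention_rnn_init(torch.zeros(B).long())
print(query0.shape)          # torch.Size([4, 256])
print(query0.requires_grad)  # True -> the initial state is trained with the model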
@@ -370,13 +370,15 @@ class Decoder(nn.Module):
     def decode(self, inputs, mask=None):
         # Prenet
         processed_memory = self.prenet(self.memory_input)
-        # Attention RNN
-        self.attention_rnn_hidden = self.attention_rnn(torch.cat((processed_memory, self.current_context_vec), -1), self.attention_rnn_hidden)
-        self.current_context_vec = self.attention_layer(self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
-        # Concat RNN output and attention context vector
+        # Attention
+        self.query = self.attention_rnn(torch.cat((processed_memory, self.context_vec), -1), self.query)
+        self.context_vec = self.attention(self.query, inputs, self.processed_inputs, mask)

+        # Concat query and attention context vector
         decoder_input = self.project_to_decoder_in(
-            torch.cat((self.attention_rnn_hidden, self.current_context_vec),
-                      -1))
+            torch.cat((self.query, self.context_vec), -1))
         # Pass through the decoder RNNs
         for idx in range(len(self.decoder_rnns)):
             self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](
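Putting the renamed pieces together, one step of this decode() reads: prenet output plus previous context feeds the attention GRU, the resulting query drives the attention module, and query plus fresh context is projected into the decoder RNN stack. A rough, self-contained sketch with plain modules standing in for the repo's Prenet and Attention, and assumed sizes:

import torch
import torch.nn as nn

B, in_features, prenet_out, query_dim = 2, 256, 128, 256  # assumed sizes
T = 40                                                    # encoder steps

prenet = nn.Sequential(nn.Linear(80, prenet_out), nn.ReLU())     # stand-in for Prenet
attention_rnn = nn.GRUCell(in_features + prenet_out, query_dim)
project_to_decoder_in = nn.Linear(query_dim + in_features, 256)

def fake_attention(query, inputs):
    # stand-in for Attention: uniform weights, i.e. a mean over encoder steps
    return inputs.mean(dim=1)

inputs = torch.randn(B, T, in_features)  # encoder outputs
memory_input = torch.randn(B, 80)        # previous mel frame(s)
context_vec = torch.zeros(B, in_features)
query = torch.zeros(B, query_dim)

processed_memory = prenet(memory_input)
query = attention_rnn(torch.cat((processed_memory, context_vec), -1), query)
context_vec = fake_attention(query, inputs)
decoder_input = project_to_decoder_in(torch.cat((query, context_vec), -1))
print(decoder_input.shape)  # torch.Size([2, 256])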
@@ -384,18 +386,18 @@ class Decoder(nn.Module):
             # Residual connection
             decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input
         decoder_output = decoder_input
-        del decoder_input
         # predict mel vectors from decoder vectors
         output = self.proj_to_mel(decoder_output)
         output = torch.sigmoid(output)

         # predict stop token
         stopnet_input = torch.cat([decoder_output, output], -1)
-        del decoder_output
         if self.separate_stopnet:
             stop_token = self.stopnet(stopnet_input.detach())
         else:
             stop_token = self.stopnet(stopnet_input)
-        return output, stop_token, self.attention_layer.attention_weights
+        return output, stop_token, self.attention.attention_weights

     def _update_memory_queue(self, new_memory):
         if self.memory_size > 0 and new_memory.shape[-1] < self.memory_size:
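The separate_stopnet branch feeds the stop-token predictor a detached copy of its input, so the stop loss never backpropagates into the decoder. A small illustration of the difference, with a toy linear layer standing in for the real stopnet:

import torch
import torch.nn as nn

decoder_output = torch.randn(2, 256, requires_grad=True)
output = torch.sigmoid(torch.randn(2, 80, requires_grad=True))
stopnet = nn.Linear(256 + 80, 1)  # toy stand-in for the real stopnet

stopnet_input = torch.cat([decoder_output, output], -1)

detached = stopnet(stopnet_input.detach())  # separate_stopnet=True: no grad reaches the decoder
joint = stopnet(stopnet_input)              # separate_stopnet=False: stop loss also trains the decoder

detached.sum().backward()
print(decoder_output.grad)                  # None -> decoder untouched by the stop loss
joint.sum().backward()
print(decoder_output.grad.abs().sum() > 0)  # tensor(True)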
@@ -427,7 +429,7 @@ class Decoder(nn.Module):
         stop_tokens = []
         t = 0
         self._init_states(inputs)
-        self.attention_layer.init_states(inputs)
+        self.attention.init_states(inputs)
         while len(outputs) < memory.size(0):
             if t > 0:
                 new_memory = memory[t - 1]
@@ -453,8 +455,8 @@ class Decoder(nn.Module):
         stop_tokens = []
         t = 0
         self._init_states(inputs)
-        self.attention_layer.init_win_idx()
-        self.attention_layer.init_states(inputs)
+        self.attention.init_win_idx()
+        self.attention.init_states(inputs)
         while True:
             if t > 0:
                 new_memory = outputs[-1]
@@ -104,7 +104,7 @@ class Decoder(nn.Module):
         self.r = r
         self.encoder_embedding_dim = in_features
         self.separate_stopnet = separate_stopnet
-        self.attention_rnn_dim = 1024
+        self.query_dim = 1024
         self.decoder_rnn_dim = 1024
         self.prenet_dim = 256
         self.max_decoder_steps = 1000
@@ -117,21 +117,21 @@ class Decoder(nn.Module):
             [self.prenet_dim, self.prenet_dim], bias=False)

         self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
-                                         self.attention_rnn_dim)
+                                         self.query_dim)

-        self.attention_layer = Attention(attention_rnn_dim=self.attention_rnn_dim,
+        self.attention = Attention(query_dim=self.query_dim,
                                    embedding_dim=in_features,
                                    attention_dim=128,
                                    location_attention=location_attn,
                                    attention_location_n_filters=32,
                                    attention_location_kernel_size=31,
                                    windowing=attn_win,
                                    norm=attn_norm,
                                    forward_attn=forward_attn,
                                    trans_agent=trans_agent,
                                    forward_attn_mask=forward_attn_mask)

-        self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
+        self.decoder_rnn = nn.LSTMCell(self.query_dim + in_features,
                                        self.decoder_rnn_dim, 1)

         self.linear_projection = Linear(self.decoder_rnn_dim + in_features,
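Unlike the GRU-based decoder above, this decoder's attention_rnn is an LSTMCell, which returns a (hidden, cell) pair; that is why the later hunks keep both self.query and self.attention_rnn_cell_state. A quick shape sketch using this hunk's dims (prenet_dim=256, query_dim=1024; in_features=512 is assumed):

import torch
import torch.nn as nn

B, prenet_dim, in_features, query_dim = 2, 256, 512, 1024  # in_features assumed

attention_rnn = nn.LSTMCell(prenet_dim + in_features, query_dim)

query_input = torch.randn(B, prenet_dim + in_features)
query = torch.zeros(B, query_dim)                    # hidden state, used as the attention query
attention_rnn_cell_state = torch.zeros(B, query_dim) # LSTM cell state

query, attention_rnn_cell_state = attention_rnn(
    query_input, (query, attention_rnn_cell_state))
print(query.shape, attention_rnn_cell_state.shape)   # both torch.Size([2, 1024])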
@@ -145,7 +145,7 @@ class Decoder(nn.Module):
             bias=True,
             init_gain='sigmoid'))

-        self.attention_rnn_init = nn.Embedding(1, self.attention_rnn_dim)
+        self.attention_rnn_init = nn.Embedding(1, self.query_dim)
         self.go_frame_init = nn.Embedding(1, self.mel_channels * r)
         self.decoder_rnn_inits = nn.Embedding(1, self.decoder_rnn_dim)
         self.memory_truncated = None
@@ -160,10 +160,10 @@ class Decoder(nn.Module):
         # T = inputs.size(1)

         if not keep_states:
-            self.attention_hidden = self.attention_rnn_init(
+            self.query = self.attention_rnn_init(
                 inputs.data.new_zeros(B).long())
-            self.attention_cell = Variable(
-                inputs.data.new(B, self.attention_rnn_dim).zero_())
+            self.attention_rnn_cell_state = Variable(
+                inputs.data.new(B, self.query_dim).zero_())

             self.decoder_hidden = self.decoder_rnn_inits(
                 inputs.data.new_zeros(B).long())
@@ -174,7 +174,7 @@ class Decoder(nn.Module):
             inputs.data.new(B, self.encoder_embedding_dim).zero_())

         self.inputs = inputs
-        self.processed_inputs = self.attention_layer.inputs_layer(inputs)
+        self.processed_inputs = self.attention.inputs_layer(inputs)
         self.mask = mask

     def _reshape_memory(self, memories):
@@ -193,18 +193,18 @@ class Decoder(nn.Module):
         return outputs, stop_tokens, alignments

     def decode(self, memory):
-        cell_input = torch.cat((memory, self.context), -1)
-        self.attention_hidden, self.attention_cell = self.attention_rnn(
-            cell_input, (self.attention_hidden, self.attention_cell))
-        self.attention_hidden = F.dropout(
-            self.attention_hidden, self.p_attention_dropout, self.training)
-        self.attention_cell = F.dropout(
-            self.attention_cell, self.p_attention_dropout, self.training)
+        query_input = torch.cat((memory, self.context), -1)
+        self.query, self.attention_rnn_cell_state = self.attention_rnn(
+            query_input, (self.query, self.attention_rnn_cell_state))
+        self.query = F.dropout(
+            self.query, self.p_attention_dropout, self.training)
+        self.attention_rnn_cell_state = F.dropout(
+            self.attention_rnn_cell_state, self.p_attention_dropout, self.training)

-        self.context = self.attention_layer(self.attention_hidden, self.inputs,
-                                            self.processed_inputs, self.mask)
+        self.context = self.attention(self.query, self.inputs,
+                                      self.processed_inputs, self.mask)

-        memory = torch.cat((self.attention_hidden, self.context), -1)
+        memory = torch.cat((self.query, self.context), -1)
         self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
             memory, (self.decoder_hidden, self.decoder_cell))
         self.decoder_hidden = F.dropout(self.decoder_hidden,
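Note that F.dropout here takes the module's training flag positionally, so the attention dropout on the query and cell state is active only during training and is an identity in eval mode. A small check of that behavior (the dropout probability is an assumed value):

import torch
import torch.nn.functional as F

query = torch.randn(2, 1024)
p_attention_dropout = 0.1  # assumed value

train_out = F.dropout(query, p_attention_dropout, True)   # training=True: units zeroed, survivors scaled by 1/(1-p)
eval_out = F.dropout(query, p_attention_dropout, False)   # training=False: identity
print(torch.equal(eval_out, query))   # True
print(torch.equal(train_out, query))  # almost surely False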
@@ -223,7 +223,7 @@ class Decoder(nn.Module):
             stop_token = self.stopnet(stopnet_input.detach())
         else:
             stop_token = self.stopnet(stopnet_input)
-        return decoder_output, stop_token, self.attention_layer.attention_weights
+        return decoder_output, stop_token, self.attention.attention_weights

     def forward(self, inputs, memories, mask):
         memory = self.get_go_frame(inputs).unsqueeze(0)
@@ -232,7 +232,7 @@ class Decoder(nn.Module):
         memories = self.prenet(memories)

         self._init_states(inputs, mask=mask)
-        self.attention_layer.init_states(inputs)
+        self.attention.init_states(inputs)

         outputs, stop_tokens, alignments = [], [], []
         while len(outputs) < memories.size(0) - 1:
@@ -251,8 +251,8 @@ class Decoder(nn.Module):
         memory = self.get_go_frame(inputs)
         self._init_states(inputs, mask=None)

-        self.attention_layer.init_win_idx()
-        self.attention_layer.init_states(inputs)
+        self.attention.init_win_idx()
+        self.attention.init_states(inputs)

         outputs, stop_tokens, alignments, t = [], [], [], 0
         stop_flags = [True, False, False]
@@ -295,8 +295,8 @@ class Decoder(nn.Module):
         else:
             self._init_states(inputs, mask=None, keep_states=True)

-        self.attention_layer.init_win_idx()
-        self.attention_layer.init_states(inputs)
+        self.attention.init_win_idx()
+        self.attention.init_states(inputs)
         outputs, stop_tokens, alignments, t = [], [], [], 0
         stop_flags = [True, False, False]
         stop_count = 0