remove intermediate tensor asap

Eren Golge 2018-12-18 01:30:30 +01:00
parent b6f559d315
commit 4826e7db9c
2 changed files with 10 additions and 7 deletions
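
The rationale: PyTorch's caching allocator can reuse a tensor's storage as soon as its last reference is dropped, so deleting large intermediates right after their final use lowers the peak memory of a long decoder loop. A minimal sketch of the pattern, with illustrative shapes and layer names that are not from this repo:

    import torch
    from torch import nn

    enc = nn.Linear(80, 1024)
    proj = nn.Linear(1024, 80)

    with torch.no_grad():  # without a graph, `del` really releases storage
        frame = torch.randn(32, 80)
        outputs = []
        for _ in range(100):
            hidden = enc(frame)            # large intermediate
            frame = proj(torch.tanh(hidden))
            del hidden                     # free it before the next iteration allocates
            outputs.append(frame)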

@@ -86,15 +86,15 @@ class LocationSensitiveAttention(nn.Module):
         if query.dim() == 2:
             # insert time-axis for broadcasting
             query = query.unsqueeze(1)
-        loc_conv = self.loc_conv(loc)
-        loc_conv = loc_conv.transpose(1, 2)
-        processed_loc = self.loc_linear(loc_conv)
+        processed_loc = self.loc_linear(self.loc_conv(loc).transpose(1, 2))
         processed_query = self.query_layer(query)
         # cache annots
         if self.processed_annots is None:
             self.processed_annots = self.annot_layer(annot)
         alignment = self.v(
             torch.tanh(processed_query + self.processed_annots + processed_loc))
+        del processed_loc
+        del processed_query
         # (batch, max_time)
         return alignment.squeeze(-1)
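
Two things happen in this hunk: the location-feature projection is chained into a single expression, so the conv output loses its last Python reference as soon as loc_linear consumes it, and the query/location projections are deleted right after the energy computation. One caveat, as a general PyTorch point rather than a claim the commit makes: during training, tensors that autograd saved for backward stay allocated regardless of `del`, so the peak-memory win is largest under torch.no_grad(). A minimal sketch of the chaining idea, with hypothetical layer shapes:

    import torch
    from torch import nn

    conv = nn.Conv1d(1, 32, kernel_size=31, padding=15)
    linear = nn.Linear(32, 128)
    loc = torch.randn(4, 1, 100)  # (batch, channels, time); shapes are illustrative

    # Named intermediates stay bound until the scope ends:
    loc_conv = conv(loc)                    # (B, 32, T)
    loc_conv_t = loc_conv.transpose(1, 2)   # (B, T, 32) view; shares storage
    processed_loc = linear(loc_conv_t)

    # Chained form: the conv output loses its last name as soon as
    # linear() consumes it (autograd may still retain it for backward):
    processed_loc = linear(conv(loc).transpose(1, 2))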
@@ -138,11 +138,9 @@ class AttentionRNNCell(nn.Module):
         """
         if t == 0:
             self.alignment_model.reset()
-        # Concat input query and previous context
-        rnn_input = torch.cat((memory, context), -1)
         # Feed it to RNN
         # s_i = f(y_{i-1}, c_{i}, s_{i-1})
-        rnn_output = self.rnn_cell(rnn_input, rnn_state)
+        rnn_output = self.rnn_cell(torch.cat((memory, context), -1), rnn_state)
         # Alignment
         # (batch, max_time)
         # e_{ij} = a(s_{i-1}, h_j)
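
For context, the comments above describe Bahdanau-style additive attention: e_{ij} = v^T tanh(W s_{i-1} + V h_j + U f_{ij}), where f are location features. A minimal sketch of that scoring step, with hypothetical layer names mirroring the hunk above:

    import torch
    from torch import nn

    class AdditiveScore(nn.Module):
        # e_ij = v^T tanh(W s + V h + U f); dimensions are illustrative
        def __init__(self, query_dim, annot_dim, loc_dim, attn_dim):
            super().__init__()
            self.query_layer = nn.Linear(query_dim, attn_dim, bias=False)
            self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=False)
            self.loc_layer = nn.Linear(loc_dim, attn_dim, bias=False)
            self.v = nn.Linear(attn_dim, 1, bias=False)

        def forward(self, query, annots, loc_feats):
            # query: (B, 1, Q), annots: (B, T, A), loc_feats: (B, T, L)
            energies = self.v(torch.tanh(
                self.query_layer(query)        # broadcasts over the time axis
                + self.annot_layer(annots)
                + self.loc_layer(loc_feats)))
            return energies.squeeze(-1)        # (B, T)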

@@ -403,6 +403,7 @@ class Decoder(nn.Module):
             attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
                 processed_memory, current_context_vec, attention_rnn_hidden,
                 inputs, attention_cat, mask, t)
+            del attention_cat
             attention_cum += attention
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(
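
Whether a deletion like `del attention_cat` actually moves the needle is easy to check, since PyTorch tracks peak allocations per device. A hedged way to measure it (CUDA-only counters, assuming a GPU; the helper and step functions below are hypothetical):

    import torch

    def peak_mib(fn, *args):
        # Report the peak CUDA memory a single call touched, in MiB.
        torch.cuda.reset_peak_memory_stats()
        fn(*args)
        return torch.cuda.max_memory_allocated() / 2**20

    # Usage sketch: compare one decoder step with and without the early dels.
    # print(peak_mib(step_with_del, batch))
    # print(peak_mib(step_without_del, batch))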
@@ -414,15 +415,19 @@ class Decoder(nn.Module):
                 # Residual connection
                 decoder_input = decoder_rnn_hiddens[idx] + decoder_input
             decoder_output = decoder_input
+            del decoder_input
             # predict mel vectors from decoder vectors
             output = self.proj_to_mel(decoder_output)
             output = torch.sigmoid(output)
             # predict stop token
-            stopnet_input = torch.cat([decoder_input, output], -1)
+            stopnet_input = torch.cat([decoder_output, output], -1)
+            del decoder_output
             stop_token = self.stopnet(stopnet_input)
+            del stopnet_input
             outputs += [output]
             attentions += [attention]
             stop_tokens += [stop_token]
+            del output
             t += 1
             if memory is not None:
                 if t >= T_decoder:
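
One subtlety in the last hunk, as a general Python point rather than something the commit states: `del` removes a name binding, not the object. Because `output` was just appended to `outputs`, the tensor survives through the list and `del output` frees nothing; `del stopnet_input`, by contrast, can drop the last reference once stopnet has produced its result. A small demonstration:

    import torch

    with torch.no_grad():
        keep = []
        t = torch.randn(1024, 1024)
        keep.append(t)
        del t   # the list still references the tensor: storage survives

        u = torch.randn(1024, 1024)
        del u   # last reference gone: storage returns to the allocator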