From 4826e7db9c970fec96b5f1ffbb908dcaa6df4ce0 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Tue, 18 Dec 2018 01:30:30 +0100
Subject: [PATCH] remove intermediate tensor asap

---
 layers/attention.py | 10 ++++------
 layers/tacotron.py  |  7 ++++++-
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/layers/attention.py b/layers/attention.py
index 534e4ba4..ea31768a 100644
--- a/layers/attention.py
+++ b/layers/attention.py
@@ -86,15 +86,15 @@ class LocationSensitiveAttention(nn.Module):
         if query.dim() == 2:
             # insert time-axis for broadcasting
             query = query.unsqueeze(1)
-        loc_conv = self.loc_conv(loc)
-        loc_conv = loc_conv.transpose(1, 2)
-        processed_loc = self.loc_linear(loc_conv)
+        processed_loc = self.loc_linear(self.loc_conv(loc).transpose(1, 2))
         processed_query = self.query_layer(query)
         # cache annots
         if self.processed_annots is None:
             self.processed_annots = self.annot_layer(annot)
         alignment = self.v(
             torch.tanh(processed_query + self.processed_annots + processed_loc))
+        del processed_loc
+        del processed_query
         # (batch, max_time)
         return alignment.squeeze(-1)

@@ -138,11 +138,9 @@ class AttentionRNNCell(nn.Module):
         """
         if t == 0:
             self.alignment_model.reset()
-        # Concat input query and previous context context
-        rnn_input = torch.cat((memory, context), -1)
         # Feed it to RNN
         # s_i = f(y_{i-1}, c_{i}, s_{i-1})
-        rnn_output = self.rnn_cell(rnn_input, rnn_state)
+        rnn_output = self.rnn_cell(torch.cat((memory, context), -1), rnn_state)
         # Alignment
         # (batch, max_time)
         # e_{ij} = a(s_{i-1}, h_j)
diff --git a/layers/tacotron.py b/layers/tacotron.py
index e9d7d7ce..6b6cadb0 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -403,6 +403,7 @@ class Decoder(nn.Module):
             attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
                 processed_memory, current_context_vec, attention_rnn_hidden, inputs,
                 attention_cat, mask, t)
+            del attention_cat
             attention_cum += attention
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(
@@ -414,15 +415,19 @@ class Decoder(nn.Module):
                 # Residual connectinon
                 decoder_input = decoder_rnn_hiddens[idx] + decoder_input
             decoder_output = decoder_input
+            del decoder_input
             # predict mel vectors from decoder vectors
             output = self.proj_to_mel(decoder_output)
             output = torch.sigmoid(output)
             # predict stop token
-            stopnet_input = torch.cat([decoder_input, output], -1)
+            stopnet_input = torch.cat([decoder_output, output], -1)
+            del decoder_output
             stop_token = self.stopnet(stopnet_input)
+            del stopnet_input
             outputs += [output]
             attentions += [attention]
             stop_tokens += [stop_token]
+            del output
             t += 1
             if memory is not None:
                 if t >= T_decoder:
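
Note on the technique (not part of the patch): in PyTorch, `del` drops the local
reference to a tensor, so the caching allocator can reuse its storage as soon as no
other references remain; inside a long decoder loop this lowers peak memory. A
minimal sketch of the same pattern, with assumed shapes and layer names (the GRU
cell, sizes, and the decode_steps helper are hypothetical, not from this repo):

    import torch
    import torch.nn as nn

    rnn = nn.GRUCell(80, 256)   # assumed sizes: 80-dim frames, 256-dim state
    proj = nn.Linear(256, 80)

    def decode_steps(frame, hidden, steps=100):
        outputs = []
        for _ in range(steps):
            hidden = rnn(frame, hidden)
            projected = proj(hidden)       # intermediate tensor
            frame = torch.sigmoid(projected)
            del projected                  # release it before the next step allocates
            outputs.append(frame)
        return torch.stack(outputs, dim=1)  # (batch, steps, 80)

    with torch.no_grad():
        frames = decode_steps(torch.zeros(4, 80), torch.zeros(4, 256))

Caveat: under autograd, activations saved for the backward pass stay alive
regardless of `del`, so the savings are largest at inference (torch.no_grad())
or for tensors the backward pass does not need.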