From cdaaff9dbbc6fe89334be2d2c6e4000db116565b Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Wed, 28 Nov 2018 16:31:29 +0100
Subject: [PATCH] Modularize memory reshaping in decoder layer

---
 layers/tacotron.py | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/layers/tacotron.py b/layers/tacotron.py
index e29d0d0a..f2077106 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -337,6 +337,17 @@ class Decoder(nn.Module):
             self.proj_to_mel.weight,
             gain=torch.nn.init.calculate_gain('linear'))
 
+    def _reshape_memory(self, memory):
+        if memory is not None:
+            B = memory.shape[0]
+            # Grouping multiple frames if necessary
+            if memory.size(-1) == self.memory_dim:
+                memory = memory.contiguous()
+                memory = memory.view(B, memory.size(1) // self.r, -1)
+            # Time first (T_decoder, B, memory_dim)
+            memory = memory.transpose(0, 1)
+        return memory
+
     def forward(self, inputs, memory=None, mask=None):
         """
         Decoder forward step.
@@ -359,14 +370,8 @@
         T = inputs.size(1)
         # Run greedy decoding if memory is None
        greedy = not self.training
-        if memory is not None:
-            # Grouping multiple frames if necessary
-            if memory.size(-1) == self.memory_dim:
-                memory = memory.contiguous()
-                memory = memory.view(B, memory.size(1) // self.r, -1)
-                " !! Dimension mismatch {} vs {} * {}".format(
-                    memory.size(-1), self.memory_dim, self.r)
-            T_decoder = memory.size(1)
+        memory = self._reshape_memory(memory)
+        T_decoder = memory.size(0) if memory is not None else None
         # go frame as zeros matrix
         initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
         # decoder states
@@ -376,14 +381,9 @@
             for _ in range(len(self.decoder_rnns))
         ]
         current_context_vec = inputs.data.new(B, self.in_features).zero_()
-        stopnet_rnn_hidden = inputs.data.new(B,
-                                             self.r * self.memory_dim).zero_()
         # attention states
         attention = inputs.data.new(B, T).zero_()
         attention_cum = inputs.data.new(B, T).zero_()
-        # Time first (T_decoder, B, memory_dim)
-        if memory is not None:
-            memory = memory.transpose(0, 1)
         outputs = []
         attentions = []
         stop_tokens = []
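
For reference only, not part of the patch: a minimal standalone sketch of the
reshaping the new _reshape_memory helper performs, using illustrative values
memory_dim = 80 and r = 5 (assumptions for the example, not values taken from
the patch).

    import torch

    B, T, memory_dim, r = 2, 30, 80, 5

    # Ground-truth frames as produced by the data loader: (B, T, memory_dim)
    memory = torch.randn(B, T, memory_dim)

    # Group r consecutive frames into one decoder step: (B, T // r, memory_dim * r)
    memory = memory.contiguous().view(B, T // r, -1)

    # Time first, so the decoder loop can index steps directly:
    # (T // r, B, memory_dim * r)
    memory = memory.transpose(0, 1)

    print(memory.shape)  # torch.Size([6, 2, 400])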