Modularize memory reshaping in decoder layer

Eren Golge 2018-11-28 16:31:29 +01:00
parent bb2a88a984
commit cdaaff9dbb
1 changed file with 13 additions and 13 deletions
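
For readers skimming the diff below: the reshaping that this commit factors out of forward() into _reshape_memory() takes a batch of mel frames shaped (B, T, memory_dim), groups every r consecutive frames into one decoder step, and returns a time-major tensor for the decoding loop. A minimal standalone sketch of that transformation, using illustrative sizes (B=2, T=12, memory_dim=80, r=3) that are not taken from the commit:

import torch

B, T, memory_dim, r = 2, 12, 80, 3        # illustrative sizes only
memory = torch.randn(B, T, memory_dim)    # (B, T, memory_dim) mel frames

# Group r consecutive frames into one decoder step when frames are ungrouped
if memory.size(-1) == memory_dim:
    memory = memory.contiguous()
    memory = memory.view(B, memory.size(1) // r, -1)  # (B, T // r, memory_dim * r)

# Time-first layout consumed by the decoding loop
memory = memory.transpose(0, 1)           # (T // r, B, memory_dim * r)

print(memory.shape)                       # torch.Size([4, 2, 240])
T_decoder = memory.size(0)                # 4 decoder steps

Because the helper returns a time-major tensor, forward() now reads the number of decoder steps from dimension 0 (memory.size(0)) instead of dimension 1, as the second hunk shows.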


@@ -337,6 +337,17 @@ class Decoder(nn.Module):
             self.proj_to_mel.weight,
             gain=torch.nn.init.calculate_gain('linear'))
 
+    def _reshape_memory(self, memory):
+        B = memory.shape[0]
+        if memory is not None:
+            # Grouping multiple frames if necessary
+            if memory.size(-1) == self.memory_dim:
+                memory = memory.contiguous()
+                memory = memory.view(B, memory.size(1) // self.r, -1)
+            # Time first (T_decoder, B, memory_dim)
+            memory = memory.transpose(0, 1)
+        return memory
+
     def forward(self, inputs, memory=None, mask=None):
         """
         Decoder forward step.
@@ -359,14 +370,8 @@
         T = inputs.size(1)
         # Run greedy decoding if memory is None
         greedy = not self.training
-        if memory is not None:
-            # Grouping multiple frames if necessary
-            if memory.size(-1) == self.memory_dim:
-                memory = memory.contiguous()
-                memory = memory.view(B, memory.size(1) // self.r, -1)
-                " !! Dimension mismatch {} vs {} * {}".format(
-                    memory.size(-1), self.memory_dim, self.r)
-            T_decoder = memory.size(1)
+        memory = self._reshape_memory(memory)
+        T_decoder = memory.size(0)
         # go frame as zeros matrix
         initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
         # decoder states
@@ -376,14 +381,9 @@
             for _ in range(len(self.decoder_rnns))
         ]
         current_context_vec = inputs.data.new(B, self.in_features).zero_()
-        stopnet_rnn_hidden = inputs.data.new(B,
-                                             self.r * self.memory_dim).zero_()
         # attention states
         attention = inputs.data.new(B, T).zero_()
         attention_cum = inputs.data.new(B, T).zero_()
-        # Time first (T_decoder, B, memory_dim)
-        if memory is not None:
-            memory = memory.transpose(0, 1)
         outputs = []
         attentions = []
         stop_tokens = []