diff --git a/layers/tacotron.py b/layers/tacotron.py
index 01fed238..f6ab85e7 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -252,7 +252,7 @@ class Decoder(nn.Module):
         stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
         # attention states
         attention = inputs.data.new(B, T).zero_()
-        # attention_cum = inputs.data.new(B, T).zero_()
+        attention_cum = inputs.data.new(B, T).zero_()
         # Time first (T_decoder, B, memory_dim)
         if memory is not None:
             memory = memory.transpose(0, 1)
@@ -270,13 +270,13 @@ class Decoder(nn.Module):
             # Prenet
             processed_memory = self.prenet(memory_input)
             # Attention RNN
-            # attention_cat = torch.cat((attention.unsqueeze(1),
-            #                            attention_cum.unsqueeze(1)),
-            #                           dim=1)
+            attention_cat = torch.cat((attention.unsqueeze(1),
+                                       attention_cum.unsqueeze(1)),
+                                      dim=1)
             attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
                 processed_memory, current_context_vec, attention_rnn_hidden,
-                inputs, attention.unsqueeze(1), input_lens)
-            # attention_cum += attention
+                inputs, attention_cat, input_lens)
+            attention_cum += attention
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(
                 torch.cat((attention_rnn_hidden, current_context_vec), -1))
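
Note on the change above: the patch switches the attention RNN input from the single current alignment to a 2-channel tensor holding the current alignment and its running sum. A minimal standalone sketch of the tensor shapes this produces (B and T below are illustrative values, not taken from the patch):

import torch

B, T = 4, 17                                           # assumed batch size and encoder length
attention = torch.softmax(torch.randn(B, T), dim=-1)   # current alignment, shape (B, T)
attention_cum = torch.zeros(B, T)                      # running sum of past alignments, shape (B, T)

# Stack the two alignment maps along a new channel dimension, as in the patch.
attention_cat = torch.cat((attention.unsqueeze(1),
                           attention_cum.unsqueeze(1)),
                          dim=1)                       # shape (B, 2, T)
assert attention_cat.shape == (B, 2, T)

# Accumulate the current alignment for the next decoder step, as in the patch.
attention_cum += attention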