diff --git a/layers/tacotron.py b/layers/tacotron.py
index 496a86c8..0e601512 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -2,8 +2,6 @@
 import torch
 from torch import nn
 from .attention import AttentionRNN
-from .attention import get_mask_from_lengths
-
 
 class Prenet(nn.Module):
     r""" Prenet as explained at https://arxiv.org/abs/1703.10135.
@@ -270,8 +268,8 @@ class Decoder(nn.Module):
                                for _ in range(len(self.decoder_rnns))]
         current_context_vec = inputs.data.new(B, 256).zero_()
         stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
-        attention_vec = memory.data.new(B, T).zero_()
-        attention_vec_cum = memory.data.new(B, T).zero_()
+        attention_vec = inputs.data.new(B, T).zero_()
+        attention_vec_cum = inputs.data.new(B, T).zero_()
         # Time first (T_decoder, B, memory_dim)
         if memory is not None:
             memory = memory.transpose(0, 1)
@@ -290,12 +288,11 @@ class Decoder(nn.Module):
             processed_memory = self.prenet(memory_input)
             # Attention RNN
             attention_vec_cat = torch.cat((attention_vec.unsqueeze(1),
-                                           attention_vec_cum.unsqueeze(1)),
+                                           attention_vec_cum.unsqueeze(1) / (t + 1)),
                                           dim=1)
             attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
                 processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention_vec_cat)
             attention_vec_cum += attention_vec
-            attention_vec_cum /= (t + 1)
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(
                 torch.cat((attention_rnn_hidden, current_context_vec), -1))
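
Context for the second and third hunks: the attention buffers are now allocated from `inputs` rather than `memory` (presumably because `memory` can be `None`, as the `if memory is not None:` guard suggests), and the `(t + 1)` division is applied only when the cumulative alignment is read for the attention RNN, so `attention_vec_cum` stays a plain running sum instead of being divided in place every step. The snippet below is a minimal, standalone sketch of just that loop, not code from this repo; `B`, `T`, `steps`, and the random softmax alignments are illustrative assumptions.

# Sketch contrasting the old in-place normalization with the patched
# read-time normalization of the cumulative attention vector.
import torch

B, T = 2, 5        # batch size, number of encoder frames (assumed values)
steps = 4          # simulated decoder steps (assumed value)


def run_decoder_loop(normalize_in_place):
    attention_vec_cum = torch.zeros(B, T)
    for t in range(steps):
        # stand-in for the alignment the attention RNN would produce
        attention_vec = torch.softmax(torch.randn(B, T), dim=1)

        # second channel fed to the attention module, as in the patched code:
        # the cumulative alignment is normalized only when it is read
        attention_vec_cat = torch.cat(
            (attention_vec.unsqueeze(1),
             attention_vec_cum.unsqueeze(1) / (t + 1)),
            dim=1)
        assert attention_vec_cat.shape == (B, 2, T)

        attention_vec_cum += attention_vec
        if normalize_in_place:
            # old behaviour: dividing the running sum every step keeps
            # shrinking the contribution of earlier alignments
            attention_vec_cum /= (t + 1)
    return attention_vec_cum


old = run_decoder_loop(normalize_in_place=True)
new = run_decoder_loop(normalize_in_place=False)

# each alignment row sums to 1, so a true cumulative sum adds up to `steps`
print(new.sum(dim=1))   # ~ tensor([4., 4.])
print(old.sum(dim=1))   # ~ tensor([0.42, 0.42]) with the old in-place division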