diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index c0aeafda..1833b0eb 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -131,7 +131,7 @@ class Attention(nn.Module):
             embedding_dim, attention_dim, bias=False, init_gain='tanh')
         self.v = Linear(attention_dim, 1, bias=True)
         if trans_agent:
-            self.ta = nn.Linear(attention_dim + embedding_dim, 1, bias=True)
+            self.ta = nn.Linear(attention_rnn_dim + embedding_dim, 1, bias=True)
         if location_attention:
             self.location_layer = LocationLayer(attention_location_n_filters,
                                                 attention_location_kernel_size,
@@ -208,7 +208,7 @@ class Attention(nn.Module):
         self.win_idx = torch.argmax(attention, 1).long()[0].item()
         return attention
 
-    def apply_forward_attention(self, inputs, alignment, processed_query):
+    def apply_forward_attention(self, inputs, alignment, query):
         # forward attention
         prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
         alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-8) * alignment
@@ -218,7 +218,7 @@ class Attention(nn.Module):
         context = context.squeeze(1)
         # compute transition agent
         if self.trans_agent:
-            ta_input = torch.cat([context, processed_query.squeeze(1)], dim=-1)
+            ta_input = torch.cat([context, query.squeeze(1)], dim=-1)
             self.u = torch.sigmoid(self.ta(ta_input))
         return context, self.alpha
 
@@ -248,7 +248,7 @@ class Attention(nn.Module):
             self.update_location_attention(alignment)
         # apply forward attention if enabled
         if self.forward_attn:
-            context, self.attention_weights = self.apply_forward_attention(inputs, alignment, processed_query)
+            context, self.attention_weights = self.apply_forward_attention(inputs, alignment, attention_hidden_state)
         else:
             context = torch.bmm(alignment.unsqueeze(1), inputs)
             context = context.squeeze(1)