import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F


class BahdanauAttention(nn.Module):
    """Additive attention scorer: ``v^T tanh(W * query + processed_inputs)``.

    The memory (``processed_inputs``) is expected to be projected by the
    caller; only the query is projected here, through a bias-free linear
    layer, before both are summed and reduced to one score per time step.
    """

    def __init__(self, dim):
        super(BahdanauAttention, self).__init__()
        # Bias-free projection of the query into the scoring space.
        self.query_layer = nn.Linear(dim, dim, bias=False)
        self.tanh = nn.Tanh()
        # Collapses each dim-sized feature vector to a single scalar score.
        self.v = nn.Linear(dim, 1, bias=False)

    def forward(self, query, processed_inputs):
        """Compute unnormalized attention scores.

        Args:
            query: (batch, 1, dim) or (batch, dim)
            processed_inputs: (batch, max_time, dim)

        Returns:
            Scores of shape (batch, max_time); no softmax is applied here.
        """
        # A 2-D query gets a singleton time axis so that it broadcasts
        # against every time step of the memory.
        expanded_query = query.unsqueeze(1) if query.dim() == 2 else query

        # (batch, 1, dim)
        projected_query = self.query_layer(expanded_query)

        # (batch, max_time, dim)
        energy = self.tanh(projected_query + processed_inputs)

        # (batch, max_time, 1) -> (batch, max_time)
        scores = self.v(energy)
        return scores.squeeze(-1)


def get_mask_from_lengths(inputs, inputs_lengths):
    """Build a padding mask from a list of sequence lengths.

    Args:
        inputs: (batch, max_time, dim); only its first two sizes and its
            device are used.
        inputs_lengths: array like of integer lengths, one per batch row.

    Returns:
        Boolean tensor of shape (batch, max_time) on the same device as
        ``inputs``. Entries are ``True`` at padded positions (time index >=
        length) and ``False`` at valid positions, so the result can be passed
        directly to ``masked_fill``.
    """
    # Start with everything marked as padding and clear the valid prefix of
    # each row. A bool mask is used instead of inverting a uint8 mask with
    # ``~``: on uint8 that operator is a bitwise NOT (1 -> 254, 0 -> 255),
    # which would mark every position as padding.
    mask = torch.ones(inputs.size(0), inputs.size(1),
                      dtype=torch.bool, device=inputs.device)
    for idx, length in enumerate(inputs_lengths):
        mask[idx, :length] = False
    return mask


class AttentionWrapper(nn.Module):
    """Run one attention step followed by one step of a wrapped RNN cell.

    Args:
        rnn_cell: callable invoked as ``rnn_cell(cell_input, cell_state)``
            where ``cell_input`` is the query concatenated with the attention
            context along the last axis.
        alignment_model: callable invoked as
            ``alignment_model(cell_state, processed_inputs)`` that returns
            unnormalized scores of shape (batch, max_time).
        score_mask_value: score written at masked positions before the
            softmax. The default of ``-inf`` gives them zero attention weight.
    """

    def __init__(self, rnn_cell, alignment_model,
                 score_mask_value=-float("inf")):
        super(AttentionWrapper, self).__init__()
        self.rnn_cell = rnn_cell
        self.alignment_model = alignment_model
        self.score_mask_value = score_mask_value

    def forward(self, query, context_vec, cell_state, inputs,
                processed_inputs=None, mask=None, inputs_lengths=None):
        """Attend over ``inputs`` and advance the RNN cell by one step.

        Args:
            query: (batch, query_dim) input for this step.
            context_vec: unused; it is overwritten by the freshly computed
                context. Kept so the call signature stays unchanged.
            cell_state: previous state of the cell. It is used both as the
                attention query and as the state passed to ``rnn_cell``.
            inputs: (batch, max_time, dim) memory to attend over.
            processed_inputs: memory handed to the alignment model; defaults
                to ``inputs`` when omitted.
            mask: optional boolean mask reshapeable to (batch, max_time);
                ``True`` marks positions to exclude from attention.
            inputs_lengths: optional array like of lengths used to build
                ``mask`` when ``mask`` is not given.

        Returns:
            Tuple ``(cell_output, context_vec, alignment)`` where
            ``context_vec`` is (batch, dim) and ``alignment`` holds the
            softmax-normalized attention weights, (batch, max_time).
        """
        if processed_inputs is None:
            processed_inputs = inputs

        if inputs_lengths is not None and mask is None:
            mask = get_mask_from_lengths(inputs, inputs_lengths)

        # Alignment scores, (batch, max_time)
        # e_{ij} = a(s_{i-1}, h_j)
        alignment = self.alignment_model(cell_state, processed_inputs)

        if mask is not None:
            # reshape (not view) so non-contiguous masks are accepted too.
            mask = mask.reshape(query.size(0), -1)
            # Out-of-place fill: the alignment model's output is left intact
            # and autograd sees the masking, so masked scores get no gradient.
            alignment = alignment.masked_fill(mask, self.score_mask_value)

        # Normalize the scores into attention weights.
        alignment = F.softmax(alignment, dim=-1)

        # Attention context vector
        # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
        # (batch, 1, max_time) x (batch, max_time, dim) -> (batch, 1, dim).
        # Only the singleton time axis is squeezed, exactly once, so a
        # feature size of 1 is not collapsed as well.
        context_vec = torch.bmm(alignment.unsqueeze(1), inputs).squeeze(1)

        # Concatenate the input query with the new context.
        cell_input = torch.cat((query, context_vec), -1)

        # Feed it to the RNN
        # s_i = f(y_{i-1}, c_{i}, s_{i-1})
        cell_output = self.rnn_cell(cell_input, cell_state)

        return cell_output, context_vec, alignment