import torch
from torch import nn
from torch.nn import functional as F

from TTS.tts.layers.common_layers import Linear
from scipy.stats import betabinom


class LocationLayer(nn.Module):
    """Layers for Location Sensitive Attention

    Args:
        attention_dim (int): number of channels in the output tensor (the inner attention dimension).
        attention_n_filters (int, optional): number of filters in convolution. Defaults to 32.
        attention_kernel_size (int, optional): kernel size of convolution filter. Defaults to 31.
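
    Example:
        A minimal usage sketch; the batch size, sequence length and channel sizes
        below are illustrative only.

        >>> import torch
        >>> layer = LocationLayer(attention_dim=128)
        >>> attention_cat = torch.rand(2, 2, 50)  # [B, 2, T_en]
        >>> layer(attention_cat).shape
        torch.Size([2, 50, 128])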
    """
    def __init__(self,
                 attention_dim,
                 attention_n_filters=32,
                 attention_kernel_size=31):
        super(LocationLayer, self).__init__()
        self.location_conv1d = nn.Conv1d(
            in_channels=2,
            out_channels=attention_n_filters,
            kernel_size=attention_kernel_size,
            stride=1,
            padding=(attention_kernel_size - 1) // 2,
            bias=False)
        self.location_dense = Linear(
            attention_n_filters, attention_dim, bias=False, init_gain='tanh')

    def forward(self, attention_cat):
        """
        Shapes:
            attention_cat: [B, 2, T_en]
        """
        processed_attention = self.location_conv1d(attention_cat)
        processed_attention = self.location_dense(
            processed_attention.transpose(1, 2))
        return processed_attention


class GravesAttention(nn.Module):
    """Graves Attention as is ref1 with updates from ref2.
    ref1: https://arxiv.org/abs/1910.10288
    ref2: https://arxiv.org/pdf/1906.01083.pdf

    Args:
        query_dim (int): number of channels in query tensor.
        K (int): number of Gaussian heads to be used for computing attention.
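
    Example:
        A minimal usage sketch; the batch size, input length and channel sizes
        below are illustrative only.

        >>> import torch
        >>> attn = GravesAttention(query_dim=256, K=5)
        >>> inputs = torch.rand(2, 50, 512)  # [B, T_in, C_encoder]
        >>> query = torch.rand(2, 256)       # [B, C_attention_rnn]
        >>> attn.init_states(inputs)
        >>> attn(query, inputs, attn.preprocess_inputs(inputs), mask=None).shape
        torch.Size([2, 512])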
    """
    COEF = 0.3989422917366028  # numpy.sqrt(1/(2*numpy.pi))

    def __init__(self, query_dim, K):

        super(GravesAttention, self).__init__()
        self._mask_value = 1e-8
        self.K = K
        # self.attention_alignment = 0.05
        self.eps = 1e-5
        self.J = None
        self.N_a = nn.Sequential(
            nn.Linear(query_dim, query_dim, bias=True),
            nn.ReLU(),
            nn.Linear(query_dim, 3*K, bias=True))
        self.attention_weights = None
        self.mu_prev = None
        self.init_layers()

    def init_layers(self):
        torch.nn.init.constant_(self.N_a[2].bias[(2*self.K):(3*self.K)], 1.)  # bias mean
        torch.nn.init.constant_(self.N_a[2].bias[self.K:(2*self.K)], 10)  # bias std

    def init_states(self, inputs):
        if self.J is None or inputs.shape[1]+1 > self.J.shape[-1]:
            self.J = torch.arange(0, inputs.shape[1]+2.0).to(inputs.device) + 0.5
        self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device)
        self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device)

    # pylint: disable=R0201
    # pylint: disable=unused-argument
    def preprocess_inputs(self, inputs):
        return None

    def forward(self, query, inputs, processed_inputs, mask):
        """
        Shapes:
            query: [B, C_attention_rnn]
            inputs: [B, T_in, C_encoder]
            processed_inputs: placeholder (unused)
            mask: [B, T_in]
        """
        gbk_t = self.N_a(query)
        gbk_t = gbk_t.view(gbk_t.size(0), -1, self.K)

        # attention model parameters
        # each B x K
        g_t = gbk_t[:, 0, :]
        b_t = gbk_t[:, 1, :]
        k_t = gbk_t[:, 2, :]

        # dropout to decorrelate attention heads
        g_t = torch.nn.functional.dropout(g_t, p=0.5, training=self.training)

        # attention GMM parameters
        sig_t = torch.nn.functional.softplus(b_t) + self.eps

        mu_t = self.mu_prev + torch.nn.functional.softplus(k_t)
        g_t = torch.softmax(g_t, dim=-1) + self.eps

        j = self.J[:inputs.size(1)+1]

        # attention weights
        phi_t = g_t.unsqueeze(-1) * (1 / (1 + torch.sigmoid((mu_t.unsqueeze(-1) - j) / sig_t.unsqueeze(-1))))

        # discretize attention weights (difference between consecutive positions)
        alpha_t = torch.sum(phi_t, 1)
        alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1]
        alpha_t[alpha_t == 0] = 1e-8

        # apply masking
        if mask is not None:
            alpha_t.data.masked_fill_(~mask, self._mask_value)

        context = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1)
        self.attention_weights = alpha_t
        self.mu_prev = mu_t
        return context


class OriginalAttention(nn.Module):
    """Bahdanau Attention with various optional modifications. Proposed below.
    - Location sensitive attnetion: https://arxiv.org/abs/1712.05884
    - Forward Attention: https://arxiv.org/abs/1807.06736 + state masking at inference
    - Using sigmoid instead of softmax normalization
    - Attention windowing at inference time

    Note:
        Location Sensitive Attention is an attention mechanism that extends the additive attention mechanism
    to use cumulative attention weights from previous decoder time steps as an additional feature.

        Forward attention considers only the alignment paths that satisfy the monotonic condition at each
    decoder timestep. The modified attention probabilities at each timestep are computed recursively
    using a forward algorithm.

        A transition agent is further proposed for forward attention, which helps the attention mechanism
    decide whether to move forward or stay at each decoder timestep.

        Attention windowing applies a sliding window over the time steps of the input tensor, centered at the
    time step with the largest attention weight from the previous decoder step. It is especially useful at
    inference to keep the attention alignment diagonal.


    Args:
        query_dim (int): number of channels in the query tensor.
        embedding_dim (int): number of channels in the value tensor. In general, the value tensor is the output of the encoder layer.
        attention_dim (int): number of channels of the inner attention layers.
        location_attention (bool): enable/disable location sensitive attention.
        attention_location_n_filters (int): number of location attention filters.
        attention_location_kernel_size (int): filter size of location attention convolution layer.
        windowing (bool): enable/disable attention windowing at inference time. When enabled, only the time steps inside a window around the previously attended position are considered.
        norm (str): normalization method applied to the attention weights. 'softmax' or 'sigmoid'
        forward_attn (bool): enable/disable forward attention.
        trans_agent (bool): enable/disable transition agent in the forward attention.
        forward_attn_mask (bool): enable/disable explicit masking in forward attention. It is especially useful at inference time.
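
    Example:
        A minimal usage sketch of location sensitive attention; the tensor sizes
        below are illustrative only.

        >>> import torch
        >>> attn = OriginalAttention(
        ...     query_dim=256, embedding_dim=512, attention_dim=128,
        ...     location_attention=True, attention_location_n_filters=32,
        ...     attention_location_kernel_size=31, windowing=False, norm="softmax",
        ...     forward_attn=False, trans_agent=False, forward_attn_mask=False)
        >>> inputs = torch.rand(2, 50, 512)  # [B, T_en, D_en]
        >>> query = torch.rand(2, 256)       # [B, C_attn_rnn]
        >>> attn.init_states(inputs)
        >>> processed_inputs = attn.preprocess_inputs(inputs)
        >>> attn(query, inputs, processed_inputs, mask=None).shape
        torch.Size([2, 512])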
    """
    # Pylint gets confused by PyTorch conventions here
    #pylint: disable=attribute-defined-outside-init
    def __init__(self, query_dim, embedding_dim, attention_dim,
                 location_attention, attention_location_n_filters,
                 attention_location_kernel_size, windowing, norm, forward_attn,
                 trans_agent, forward_attn_mask):
        super(OriginalAttention, self).__init__()
        self.query_layer = Linear(
            query_dim, attention_dim, bias=False, init_gain='tanh')
        self.inputs_layer = Linear(
            embedding_dim, attention_dim, bias=False, init_gain='tanh')
        self.v = Linear(attention_dim, 1, bias=True)
        if trans_agent:
            self.ta = nn.Linear(
                query_dim + embedding_dim, 1, bias=True)
        if location_attention:
            self.location_layer = LocationLayer(
                attention_dim,
                attention_location_n_filters,
                attention_location_kernel_size,
            )
        self._mask_value = -float("inf")
        self.windowing = windowing
        self.win_idx = None
        self.norm = norm
        self.forward_attn = forward_attn
        self.trans_agent = trans_agent
        self.forward_attn_mask = forward_attn_mask
        self.location_attention = location_attention

    def init_win_idx(self):
        self.win_idx = -1
        self.win_back = 2
        self.win_front = 6

    def init_forward_attn(self, inputs):
        B = inputs.shape[0]
        T = inputs.shape[1]
        self.alpha = torch.cat(
            [torch.ones([B, 1]),
             torch.zeros([B, T])[:, :-1] + 1e-7], dim=1).to(inputs.device)
        self.u = (0.5 * torch.ones([B, 1])).to(inputs.device)

    def init_location_attention(self, inputs):
        B = inputs.size(0)
        T = inputs.size(1)
        self.attention_weights_cum = torch.zeros([B, T], device=inputs.device)

    def init_states(self, inputs):
        B = inputs.size(0)
        T = inputs.size(1)
        self.attention_weights = torch.zeros([B, T], device=inputs.device)
        if self.location_attention:
            self.init_location_attention(inputs)
        if self.forward_attn:
            self.init_forward_attn(inputs)
        if self.windowing:
            self.init_win_idx()

    def preprocess_inputs(self, inputs):
        return self.inputs_layer(inputs)

    def update_location_attention(self, alignments):
        self.attention_weights_cum += alignments

    def get_location_attention(self, query, processed_inputs):
        attention_cat = torch.cat((self.attention_weights.unsqueeze(1),
                                   self.attention_weights_cum.unsqueeze(1)),
                                  dim=1)
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_cat)
        energies = self.v(
            torch.tanh(processed_query + processed_attention_weights +
                       processed_inputs))
        energies = energies.squeeze(-1)
        return energies, processed_query

    def get_attention(self, query, processed_inputs):
        processed_query = self.query_layer(query.unsqueeze(1))
        energies = self.v(torch.tanh(processed_query + processed_inputs))
        energies = energies.squeeze(-1)
        return energies, processed_query

    def apply_windowing(self, attention, inputs):
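        """Mask attention scores outside a window around the previously attended time step."""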
        back_win = self.win_idx - self.win_back
        front_win = self.win_idx + self.win_front
        if back_win > 0:
            attention[:, :back_win] = -float("inf")
        if front_win < inputs.shape[1]:
            attention[:, front_win:] = -float("inf")
        # at the first decoding step there is no previous window position;
        # force the first time step to the maximum so the window starts at the beginning.
        if self.win_idx == -1:
            attention[:, 0] = attention.max()
        # Update the window
        self.win_idx = torch.argmax(attention, 1).long()[0].item()
        return attention

    def apply_forward_attention(self, alignment):
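        """Compute forward attention over the given alignment.

        Follows the recursion ``alpha_t(n) = ((1 - u_t) * alpha_{t-1}(n) +
        u_t * alpha_{t-1}(n - 1)) * y_t(n)`` (plus a small constant for numerical
        stability), where ``u_t`` is the transition probability (from the transition
        agent when enabled, 0.5 otherwise) and ``y_t`` is the current alignment,
        followed by renormalization.
        """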
        # forward attention
        fwd_shifted_alpha = F.pad(
            self.alpha[:, :-1].clone().to(alignment.device), (1, 0, 0, 0))
        # compute transition potentials
        alpha = ((1 - self.u) * self.alpha
                 + self.u * fwd_shifted_alpha
                 + 1e-8) * alignment
        # force incremental alignment
        if not self.training and self.forward_attn_mask:
            _, n = fwd_shifted_alpha.max(1)
            val, _ = alpha.max(1)
            for b in range(alignment.shape[0]):
                alpha[b, n[b] + 3:] = 0
                alpha[b, :(
                    n[b] - 1
                )] = 0  # ignore all previous states to prevent repetition.
                alpha[b,
                      (n[b] - 2
                       )] = 0.01 * val[b]  # smoothing factor for the prev step
        # renormalize attention weights
        alpha = alpha / alpha.sum(dim=1, keepdim=True)
        return alpha

    def forward(self, query, inputs, processed_inputs, mask):
        """
        shapes:
            query: [B, C_attn_rnn]
            inputs: [B, T_en, D_en]
            processed_inputs: [B, T_en, D_attn]
            mask: [B, T_en]
        """
        if self.location_attention:
            attention, _ = self.get_location_attention(
                query, processed_inputs)
        else:
            attention, _ = self.get_attention(
                query, processed_inputs)
        # apply masking
        if mask is not None:
            attention.data.masked_fill_(~mask, self._mask_value)
        # apply windowing - only in eval mode
        if not self.training and self.windowing:
            attention = self.apply_windowing(attention, inputs)

        # normalize attention values
        if self.norm == "softmax":
            alignment = torch.softmax(attention, dim=-1)
        elif self.norm == "sigmoid":
            alignment = torch.sigmoid(attention) / torch.sigmoid(
                attention).sum(
                    dim=1, keepdim=True)
        else:
            raise ValueError("Unknown value for attention norm type")

        if self.location_attention:
            self.update_location_attention(alignment)

        # apply forward attention if enabled
        if self.forward_attn:
            alignment = self.apply_forward_attention(alignment)
            self.alpha = alignment

        context = torch.bmm(alignment.unsqueeze(1), inputs)
        context = context.squeeze(1)
        self.attention_weights = alignment

        # compute transition agent
        if self.forward_attn and self.trans_agent:
            ta_input = torch.cat([context, query.squeeze(1)], dim=-1)
            self.u = torch.sigmoid(self.ta(ta_input))
        return context


class MonotonicDynamicConvolutionAttention(nn.Module):
    """Dynamic convolution attention from
    https://arxiv.org/pdf/1910.10288.pdf


    query -> linear -> tanh -> linear ->|
                                        |                                            mask values
                                        v                                              |    |
               atten_w(t-1) -|-> conv1d_dynamic -> linear -|-> tanh -> + -> softmax -> * -> * -> context
                             |-> conv1d_static  -> linear -|           |
                             |-> conv1d_prior   -> log ----------------|

    query: attention rnn output.

    Note:
        Dynamic convolution attention is a variation of the location sensitive attention with
    dynamically computed convolution filters from the previous attention scores and a set of
    constraints to keep the attention alignment diagonal.

    Args:
        query_dim (int): number of channels in the query tensor.
        embedding_dim (int): number of channels in the value tensor.
        attention_dim (int): number of channels of the inner attention layers.
        static_filter_dim (int): number of channels in the convolution layer computing the static filters.
        static_kernel_size (int): kernel size for the convolution layer computing the static filters.
        dynamic_filter_dim (int): number of channels in the convolution layer computing the dynamic filters.
        dynamic_kernel_size (int): kernel size for the convolution layer computing the dynamic filters.
        prior_filter_len (int, optional): length of the prior filter. Defaults to 11 from the paper.
        alpha (float, optional): alpha parameter of the Beta-Binomial prior. Defaults to 0.1 from the paper.
        beta (float, optional): beta parameter of the Beta-Binomial prior. Defaults to 0.9 from the paper.
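
    Example:
        A minimal usage sketch with the filter settings from the paper; the tensor
        sizes below are illustrative only.

        >>> import torch
        >>> attn = MonotonicDynamicConvolutionAttention(
        ...     query_dim=256, embedding_dim=512, attention_dim=128,
        ...     static_filter_dim=8, static_kernel_size=21,
        ...     dynamic_filter_dim=8, dynamic_kernel_size=21)
        >>> inputs = torch.rand(2, 50, 512)  # [B, T_en, D_en]
        >>> query = torch.rand(2, 256)       # [B, C_attn_rnn]
        >>> attn.init_states(inputs)
        >>> attn(query, inputs, None, mask=None).shape
        torch.Size([2, 512])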
    """
    def __init__(
        self,
        query_dim,
        embedding_dim,  # pylint: disable=unused-argument
        attention_dim,
        static_filter_dim,
        static_kernel_size,
        dynamic_filter_dim,
        dynamic_kernel_size,
        prior_filter_len=11,
        alpha=0.1,
        beta=0.9,
    ):
        super().__init__()
        self._mask_value = 1e-8
        self.dynamic_filter_dim = dynamic_filter_dim
        self.dynamic_kernel_size = dynamic_kernel_size
        self.prior_filter_len = prior_filter_len
        self.attention_weights = None
        # setup key and query layers
        self.query_layer = nn.Linear(query_dim, attention_dim)
        self.key_layer = nn.Linear(
            attention_dim, dynamic_filter_dim * dynamic_kernel_size, bias=False
        )
        self.static_filter_conv = nn.Conv1d(
            1,
            static_filter_dim,
            static_kernel_size,
            padding=(static_kernel_size - 1) // 2,
            bias=False,
        )
        self.static_filter_layer = nn.Linear(static_filter_dim, attention_dim, bias=False)
        self.dynamic_filter_layer = nn.Linear(dynamic_filter_dim, attention_dim)
        self.v = nn.Linear(attention_dim, 1, bias=False)

        prior = betabinom.pmf(range(prior_filter_len), prior_filter_len - 1,
                              alpha, beta)
        self.register_buffer("prior", torch.FloatTensor(prior).flip(0))

    # pylint: disable=unused-argument
    def forward(self, query, inputs, processed_inputs, mask):
        """
        query: [B, C_attn_rnn]
        inputs: [B, T_en, D_en]
        processed_inputs: place holder.
        mask: [B, T_en]
        """
        # compute prior filters
        prior_filter = F.conv1d(
            F.pad(self.attention_weights.unsqueeze(1),
                  (self.prior_filter_len - 1, 0)), self.prior.view(1, 1, -1))
        prior_filter = torch.log(prior_filter.clamp_min_(1e-6)).squeeze(1)
        G = self.key_layer(torch.tanh(self.query_layer(query)))
        # compute dynamic filters
        dynamic_filter = F.conv1d(
            self.attention_weights.unsqueeze(0),
            G.view(-1, 1, self.dynamic_kernel_size),
            padding=(self.dynamic_kernel_size - 1) // 2,
            groups=query.size(0),
        )
        dynamic_filter = dynamic_filter.view(query.size(0), self.dynamic_filter_dim, -1).transpose(1, 2)
        # compute static filters
        static_filter = self.static_filter_conv(self.attention_weights.unsqueeze(1)).transpose(1, 2)
        alignment = self.v(
            torch.tanh(
                self.static_filter_layer(static_filter) +
                self.dynamic_filter_layer(dynamic_filter))).squeeze(-1) + prior_filter
        # compute attention weights
        attention_weights = F.softmax(alignment, dim=-1)
        # apply masking
        if mask is not None:
            attention_weights.data.masked_fill_(~mask, self._mask_value)
        self.attention_weights = attention_weights
        # compute context
        context = torch.bmm(attention_weights.unsqueeze(1), inputs).squeeze(1)
        return context

    def preprocess_inputs(self, inputs):  # pylint: disable=no-self-use
        return None

    def init_states(self, inputs):
        B = inputs.size(0)
        T = inputs.size(1)
        self.attention_weights = torch.zeros([B, T], device=inputs.device)
        self.attention_weights[:, 0] = 1.


def init_attn(attn_type, query_dim, embedding_dim, attention_dim,
              location_attention, attention_location_n_filters,
              attention_location_kernel_size, windowing, norm, forward_attn,
              trans_agent, forward_attn_mask, attn_K):
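    """Instantiate the attention module selected by ``attn_type``.

    ``attn_type`` must be one of ``"original"``, ``"graves"`` or
    ``"dynamic_convolution"``. ``attn_K`` is only used by ``GravesAttention``;
    the dynamic convolution filter sizes are fixed to the values below.
    """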
    if attn_type == "original":
        return OriginalAttention(query_dim, embedding_dim, attention_dim,
                                 location_attention,
                                 attention_location_n_filters,
                                 attention_location_kernel_size, windowing,
                                 norm, forward_attn, trans_agent,
                                 forward_attn_mask)
    if attn_type == "graves":
        return GravesAttention(query_dim, attn_K)
    if attn_type == "dynamic_convolution":
        return MonotonicDynamicConvolutionAttention(query_dim,
                                                    embedding_dim,
                                                    attention_dim,
                                                    static_filter_dim=8,
                                                    static_kernel_size=21,
                                                    dynamic_filter_dim=8,
                                                    dynamic_kernel_size=21,
                                                    prior_filter_len=11,
                                                    alpha=0.1,
                                                    beta=0.9)

    raise RuntimeError(
        f" [!] Given Attention Type '{attn_type}' does not exist.")