import tensorflow as tf
from tensorflow import keras
from tensorflow.python.ops import math_ops
# from tensorflow_addons.seq2seq import BahdanauAttention

# NOTE: linter has a problem with the current TF release
#pylint: disable=no-value-for-parameter
#pylint: disable=unexpected-keyword-arg

class Linear(keras.layers.Layer):
    """Dense projection followed by a ReLU activation.

    Args:
        units: output size handed to the underlying ``Dense`` layer.
        use_bias: whether the ``Dense`` layer adds a bias term.
        **kwargs: forwarded untouched to ``keras.layers.Layer``.
    """

    def __init__(self, units, use_bias, **kwargs):
        super().__init__(**kwargs)
        self.linear_layer = keras.layers.Dense(
            units,
            use_bias=use_bias,
            name='linear_layer',
        )
        self.activation = keras.layers.ReLU()

    def call(self, x):
        """Project ``x`` and rectify the result.

        shapes:
            x: B x T x C
        """
        projected = self.linear_layer(x)
        return self.activation(projected)


class LinearBN(keras.layers.Layer):
    """Dense projection, batch normalization on the last axis, then ReLU.

    Args:
        units: output size handed to the underlying ``Dense`` layer.
        use_bias: whether the ``Dense`` layer adds a bias term.
        **kwargs: forwarded untouched to ``keras.layers.Layer``.
    """

    def __init__(self, units, use_bias, **kwargs):
        super().__init__(**kwargs)
        self.linear_layer = keras.layers.Dense(
            units,
            use_bias=use_bias,
            name='linear_layer',
        )
        self.batch_normalization = keras.layers.BatchNormalization(
            axis=-1,
            momentum=0.90,
            epsilon=1e-5,
            name='batch_normalization',
        )
        self.activation = keras.layers.ReLU()

    def call(self, x, training=None):
        """Project, normalize and rectify ``x``.

        ``training`` is passed on to the batch normalization layer only.

        shapes:
            x: B x T x C
        """
        projected = self.linear_layer(x)
        normalized = self.batch_normalization(projected, training=training)
        return self.activation(normalized)


class Prenet(keras.layers.Layer):
    """Stack of ``Linear`` or ``LinearBN`` layers with optional dropout.

    Args:
        prenet_type: ``"bn"`` builds ``LinearBN`` layers, ``"original"``
            builds ``Linear`` layers.
        prenet_dropout: when truthy, a single ``Dropout(rate=0.5)`` layer is
            created and applied after every linear layer.
        units: iterable of output sizes, one linear layer per entry.
        bias: passed as ``use_bias`` to every linear layer.
        **kwargs: forwarded untouched to ``keras.layers.Layer``.

    Raises:
        RuntimeError: if ``prenet_type`` is neither ``"bn"`` nor
            ``"original"``.
    """

    def __init__(self,
                 prenet_type,
                 prenet_dropout,
                 units,
                 bias,
                 **kwargs):
        super().__init__(**kwargs)
        self.prenet_type = prenet_type
        self.prenet_dropout = prenet_dropout

        # Pick the building block first, then create one per entry in `units`.
        if prenet_type == "bn":
            layer_cls = LinearBN
        elif prenet_type == "original":
            layer_cls = Linear
        else:
            raise RuntimeError(' [!] Unknown prenet type.')
        self.linear_layers = [
            layer_cls(unit, use_bias=bias, name=f'linear_layer_{idx}')
            for idx, unit in enumerate(units)
        ]

        # The dropout layer only exists when dropout was requested.
        if prenet_dropout:
            self.dropout = keras.layers.Dropout(rate=0.5)

    def call(self, x, training=None):
        """Run ``x`` through every linear layer in order.

        ``training`` is handed to the dropout layer only; the linear layers
        are called without it.

        shapes:
            x: B x T x C
        """
        out = x
        for layer in self.linear_layers:
            out = layer(out)
            if self.prenet_dropout:
                out = self.dropout(out, training=training)
        return out


def _sigmoid_norm(score):
    """Apply a sigmoid to ``score`` and rescale so axis 1 sums to one."""
    gated = tf.nn.sigmoid(score)
    total = tf.reduce_sum(gated, axis=1, keepdims=True)
    return gated / total


class Attention(keras.layers.Layer):
    """Additive (tanh) attention with optional location features and
    forward attention.

    The encoder outputs must be cached with ``process_values`` before the
    layer is called; ``call`` and the score helpers read ``self.values`` and
    ``self.processed_values`` that it sets.

    ``use_windowing``, ``use_trans_agent`` and ``use_forward_attn_mask`` are
    stored on the instance but not read anywhere in this class.

    TODO: implement attention windowing

    Args:
        attn_dim: output size of the query, inputs and location projections.
        use_loc_attn: add convolutional location features to the score.
        loc_attn_n_filters: number of filters of the location convolution.
        loc_attn_kernel_size: kernel size of the location convolution.
        use_windowing: stored only.
        norm: ``'softmax'`` or ``'sigmoid'``; selects the score normalizer.
        use_forward_attn: apply forward attention to the normalized scores.
        use_trans_agent: stored only.
        use_forward_attn_mask: stored only.
        **kwargs: forwarded untouched to ``keras.layers.Layer``.

    Raises:
        ValueError: if ``norm`` is neither ``'softmax'`` nor ``'sigmoid'``.
    """

    def __init__(self, attn_dim, use_loc_attn, loc_attn_n_filters,
                 loc_attn_kernel_size, use_windowing, norm, use_forward_attn,
                 use_trans_agent, use_forward_attn_mask, **kwargs):
        super().__init__(**kwargs)

        # configuration flags
        self.use_loc_attn = use_loc_attn
        self.loc_attn_n_filters = loc_attn_n_filters
        self.loc_attn_kernel_size = loc_attn_kernel_size
        self.use_windowing = use_windowing
        self.norm = norm
        self.use_forward_attn = use_forward_attn
        self.use_trans_agent = use_trans_agent
        self.use_forward_attn_mask = use_forward_attn_mask

        # projection layers (creation order is kept as-is on purpose)
        self.query_layer = keras.layers.Dense(
            attn_dim, use_bias=False, name='query_layer/linear_layer')
        self.inputs_layer = keras.layers.Dense(
            attn_dim, use_bias=False,
            name=f'{self.name}/inputs_layer/linear_layer')
        self.v = keras.layers.Dense(1, use_bias=True, name='v/linear_layer')

        # location-sensitive branch exists only when requested
        if use_loc_attn:
            self.location_conv1d = keras.layers.Conv1D(
                filters=loc_attn_n_filters,
                kernel_size=loc_attn_kernel_size,
                padding='same',
                use_bias=False,
                name='location_layer/location_conv1d')
            self.location_dense = keras.layers.Dense(
                attn_dim, use_bias=False,
                name='location_layer/location_dense')

        # score normalization
        if norm not in ('softmax', 'sigmoid'):
            raise ValueError("Unknown value for attention norm type")
        self.norm_func = tf.nn.softmax if norm == 'softmax' else _sigmoid_norm

    def init_states(self, batch_size, value_length):
        """Build the initial state tuple.

        With location attention the tuple starts with two zero tensors of
        shape ``[batch_size, value_length]`` (cumulative and previous
        weights). With forward attention it ends with ``alpha``: ones in the
        first column and ``1e-7`` in the remaining columns.
        """
        states = ()
        if self.use_loc_attn:
            zeros_shape = [batch_size, value_length]
            states += (tf.zeros(zeros_shape), tf.zeros(zeros_shape))
        if self.use_forward_attn:
            first_step = tf.ones([batch_size, 1])
            remaining = tf.zeros([batch_size, value_length])[:, :-1] + 1e-7
            states += (tf.concat([first_step, remaining], 1),)
        return states

    def process_values(self, values):
        """Cache ``values`` and their projection for the decoder iterations."""
        #pylint: disable=attribute-defined-outside-init
        projected = self.inputs_layer(values)
        self.processed_values = projected
        self.values = values

    def get_loc_attn(self, query, states):
        """Compute the projected query and the unnormalized scores, adding
        location features derived from the first two entries of ``states``.

        Returns:
            ``(score, processed_query)`` where ``score`` has its axis 2
            squeezed away.
        """
        prev_cum, prev_step = states[:2]
        location_input = tf.stack([prev_step, prev_cum], axis=2)

        processed_query = self.query_layer(tf.expand_dims(query, 1))
        processed_attn = self.location_dense(
            self.location_conv1d(location_input))
        energies = tf.nn.tanh(
            self.processed_values + processed_query + processed_attn)
        score = tf.squeeze(self.v(energies), axis=2)
        return score, processed_query

    def get_attn(self, query):
        """Compute the projected query and the unnormalized scores.

        Returns:
            ``(score, processed_query)`` where ``score`` has its axis 2
            squeezed away.
        """
        processed_query = self.query_layer(tf.expand_dims(query, 1))
        energies = tf.nn.tanh(self.processed_values + processed_query)
        score = tf.squeeze(self.v(energies), axis=2)
        return score, processed_query

    def apply_score_masking(self, score, mask):  #pylint: disable=no-self-use
        """Push the scores of padded positions towards minus infinity.

        Positions where ``mask`` is False get ``1e9`` subtracted.

        NOTE(review): the inverted mask is expanded on axis 2 (rank 3 for a
        rank-2 mask) while ``get_attn``/``get_loc_attn`` return scores with
        axis 2 squeezed -- confirm the intended shapes before wiring this in.
        """
        is_padding = math_ops.logical_not(mask)
        padding_mask = tf.expand_dims(is_padding, 2)
        # Large negative bias so padding does not contribute after normalization.
        score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
        return score

    def apply_forward_attention(self, alignment, alpha):  #pylint: disable=no-self-use
        """Combine ``alignment`` with the previous ``alpha``.

        ``alpha`` is mixed half-and-half with a copy of itself shifted one
        position forward along axis 1, multiplied by ``alignment`` and
        renormalized to sum to one along axis 1.
        """
        # shift right by one step, filling the first column with zeros
        shifted = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)), constant_values=0.0)
        # transition potentials: stay in place or advance one step
        unnormalized = ((1 - 0.5) * alpha + 0.5 * shifted + 1e-8) * alignment
        # renormalize attention weights
        total = tf.reduce_sum(unnormalized, axis=1, keepdims=True)
        return unnormalized / total

    def update_states(self, old_states, scores_norm, attn_weights, new_alpha=None):
        """Build the next state tuple, mirroring the layout of ``init_states``.

        The cumulative entry accumulates ``scores_norm``; the previous-weights
        entry is replaced by ``attn_weights``.
        """
        next_states = ()
        if self.use_loc_attn:
            next_states += (old_states[0] + scores_norm, attn_weights)
        if self.use_forward_attn:
            next_states += (new_alpha,)
        return next_states

    def call(self, query, states):
        """Run one attention step.

        shapes:
            query: B x D

        Returns:
            ``(context_vector, attn_weights, states)`` with the updated
            state tuple.
        """
        if self.use_loc_attn:
            score, _ = self.get_loc_attn(query, states)
        else:
            score, _ = self.get_attn(query)

        # TODO: masking
        # if mask is not None:
        # self.apply_score_masking(score, mask)

        # normalize attention scores
        scores_norm = self.norm_func(score)

        # optionally refine the weights with forward attention,
        # using the last state entry as the previous alpha
        if self.use_forward_attn:
            new_alpha = self.apply_forward_attention(scores_norm, states[-1])
            attn_weights = new_alpha
        else:
            new_alpha = None
            attn_weights = scores_norm

        # states = (cum_attn_weights, attn_weights, new_alpha)
        next_states = self.update_states(
            states, scores_norm, attn_weights, new_alpha)

        # weights get a trailing axis and are transposed inside matmul, so the
        # product with the cached values has a singleton axis 1 to squeeze
        weights_col = tf.expand_dims(attn_weights, axis=2)
        context_vector = tf.matmul(
            weights_col, self.values, transpose_a=True, transpose_b=False)
        context_vector = tf.squeeze(context_vector, axis=1)
        return context_vector, attn_weights, next_states


# def _location_sensitive_score(processed_query, keys, processed_loc, attention_v, attention_b):
#     dtype = processed_query.dtype
#     num_units = keys.shape[-1].value or array_ops.shape(keys)[-1]
#     return tf.reduce_sum(attention_v * tf.tanh(keys + processed_query + processed_loc + attention_b), [2])


# class LocationSensitiveAttention(BahdanauAttention):
#     def __init__(self,
#                  units,
#                  memory=None,
#                  memory_sequence_length=None,
#                  normalize=False,
#                  probability_fn="softmax",
#                  kernel_initializer="glorot_uniform",
#                  dtype=None,
#                  name="LocationSensitiveAttention",
#                  location_attention_filters=32,
#                  location_attention_kernel_size=31):

#         super(LocationSensitiveAttention,
#                     self).__init__(units=units,
#                                     memory=memory,
#                                     memory_sequence_length=memory_sequence_length,
#                                     normalize=normalize,
#                                     probability_fn='softmax',  ## parent module default
#                                     kernel_initializer=kernel_initializer,
#                                     dtype=dtype,
#                                     name=name)
#         if probability_fn == 'sigmoid':
#             self.probability_fn = lambda score, _: self._sigmoid_normalization(score)
#         self.location_conv = keras.layers.Conv1D(filters=location_attention_filters, kernel_size=location_attention_kernel_size, padding='same', use_bias=False)
#         self.location_dense = keras.layers.Dense(units, use_bias=False)
#         # self.v = keras.layers.Dense(1, use_bias=True)

#     def  _location_sensitive_score(self, processed_query, keys, processed_loc):
#         processed_query = tf.expand_dims(processed_query, 1)
#         return tf.reduce_sum(self.attention_v * tf.tanh(keys + processed_query + processed_loc), [2])

#     def _location_sensitive(self, alignment_cum, alignment_old):
#         alignment_cat = tf.stack([alignment_cum, alignment_old], axis=2)
#         return self.location_dense(self.location_conv(alignment_cat))

#     def _sigmoid_normalization(self, score):
#         return tf.nn.sigmoid(score) / tf.reduce_sum(tf.nn.sigmoid(score), axis=-1, keepdims=True)

#     # def _apply_masking(self, score, mask):
#     #     padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2)
#     #     # Bias so padding positions do not contribute to attention distribution.
#     #     score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
#     #     return score

#     def _calculate_attention(self, query, state):
#         alignment_cum, alignment_old = state[:2]
#         processed_query = self.query_layer(
#             query) if self.query_layer else query
#         processed_loc = self._location_sensitive(alignment_cum, alignment_old)
#         score = self._location_sensitive_score(
#             processed_query,
#             self.keys,
#             processed_loc)
#         alignment = self.probability_fn(score, state)
#         alignment_cum = alignment_cum + alignment
#         state[0] = alignment_cum
#         state[1] = alignment
#         return alignment, state

#     def compute_context(self, alignments):
#         expanded_alignments = tf.expand_dims(alignments, 1)
#         context = tf.matmul(expanded_alignments, self.values)
#         context = tf.squeeze(context, [1])
#         return context

#     # def call(self, query, state):
#     #     alignment, next_state = self._calculate_attention(query, state)
#     #     return alignment, next_state