import tensorflow as tf
from tensorflow import keras
from TTS.tf.utils.tf_utils import shape_list
from TTS.tf.layers.common_layers import Prenet, Attention
# from tensorflow_addons.seq2seq import AttentionWrapper


class ConvBNBlock(keras.layers.Layer):
    """1D convolution followed by batch norm, an activation and dropout.

    Args:
        filters: number of output channels of the convolution.
        kernel_size: width of the convolution kernel.
        activation: activation identifier passed to ``keras.layers.Activation``.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, filters, kernel_size, activation, **kwargs):
        super().__init__(**kwargs)
        # 'same' padding keeps the time dimension unchanged by the convolution.
        self.convolution1d = keras.layers.Conv1D(
            filters, kernel_size, padding='same', name='convolution1d')
        # Normalization statistics are taken over axis 2 of the conv output.
        self.batch_normalization = keras.layers.BatchNormalization(
            axis=2, momentum=0.90, epsilon=1e-5, name='batch_normalization')
        self.dropout = keras.layers.Dropout(rate=0.5, name='dropout')
        self.activation = keras.layers.Activation(activation, name='activation')

    def call(self, x, training=None):
        """Run conv -> batch norm -> activation -> dropout on ``x``.

        ``training`` is forwarded to the batch-norm and dropout layers only.
        """
        convolved = self.convolution1d(x)
        normalized = self.batch_normalization(convolved, training=training)
        activated = self.activation(normalized)
        return self.dropout(activated, training=training)


class Postnet(keras.layers.Layer):
    """Stack of ``ConvBNBlock`` layers applied one after another.

    The stack is made of 512-filter ``tanh`` blocks followed by one final
    ``linear`` block with ``output_filters`` filters. All blocks use a kernel
    size of 5.

    Args:
        output_filters: number of filters of the final (linear) block.
        num_convs: total number of blocks. The first and last blocks are
            always created, so values below 2 still yield two blocks.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, output_filters, num_convs, **kwargs):
        super(Postnet, self).__init__(**kwargs)
        self.convolutions = []
        self.convolutions.append(ConvBNBlock(512, 5, 'tanh', name='convolutions_0'))
        for idx in range(1, num_convs - 1):
            self.convolutions.append(ConvBNBlock(512, 5, 'tanh', name=f'convolutions_{idx}'))
        # Compute the last index explicitly instead of reusing the loop
        # variable: the loop body never runs for num_convs <= 2, which left
        # ``idx`` unbound and raised NameError. For num_convs >= 3 this is the
        # same value as the former ``idx + 1``, so layer names are unchanged.
        last_idx = max(num_convs - 1, 1)
        self.convolutions.append(
            ConvBNBlock(output_filters, 5, 'linear', name=f'convolutions_{last_idx}'))

    def call(self, x, training=None):
        """Pass ``x`` through every block in order, forwarding ``training``."""
        o = x
        for layer in self.convolutions:
            o = layer(o, training=training)
        return o


class Encoder(keras.layers.Layer):
    """Three ReLU ``ConvBNBlock`` layers followed by a bidirectional LSTM.

    Args:
        output_input_dim: number of filters of every convolution block. Each
            LSTM direction gets ``output_input_dim // 2`` units.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, output_input_dim, **kwargs):
        super().__init__(**kwargs)
        self.convolutions = [
            ConvBNBlock(output_input_dim, 5, 'relu', name=f'convolutions_{idx}')
            for idx in range(3)
        ]
        recurrent = keras.layers.LSTM(output_input_dim // 2,
                                      return_sequences=True,
                                      use_bias=True)
        self.lstm = keras.layers.Bidirectional(recurrent, name='lstm')

    def call(self, x, training=None):
        """Encode ``x``; ``training`` is forwarded to the convolution blocks only."""
        hidden = x
        for conv_block in self.convolutions:
            hidden = conv_block(hidden, training=training)
        return self.lstm(hidden)


class Decoder(keras.layers.Layer):
    #pylint: disable=unused-argument
    def __init__(self, frame_dim, r, attn_type, use_attn_win, attn_norm, prenet_type,
                 prenet_dropout, use_forward_attn, use_trans_agent, use_forward_attn_mask,
                 use_location_attn, attn_K, separate_stopnet, speaker_emb_dim, **kwargs):
        """Create the decoder sub-layers.

        Args:
            frame_dim: size of a single output frame.
            r: number of frames predicted per decoder step; also fixes the
                width of the linear projection (``frame_dim * r``).
            attn_norm: forwarded to ``Attention`` as ``norm``.
            prenet_type, prenet_dropout: forwarded to ``Prenet``.
            use_forward_attn, use_trans_agent, use_forward_attn_mask:
                forwarded to ``Attention``.
            separate_stopnet: stored on the instance; not otherwise used here.
            attn_type, use_attn_win, use_location_attn, attn_K, speaker_emb_dim:
                accepted but currently unused by this constructor.
            **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
        """
        super().__init__(**kwargs)
        self.frame_dim = frame_dim
        # r_init keeps the construction-time value; r can be changed via set_r().
        self.r_init = tf.constant(r, dtype=tf.int32)
        self.r = tf.constant(r, dtype=tf.int32)
        self.separate_stopnet = separate_stopnet
        self.max_decoder_steps = tf.constant(1000, dtype=tf.int32)
        self.stop_thresh = tf.constant(0.5, dtype=tf.float32)

        # model dimensions
        self.query_dim = 1024
        self.decoder_rnn_dim = 1024
        self.prenet_dim = 256
        self.attn_dim = 128
        self.p_attention_dropout = 0.1
        self.p_decoder_dropout = 0.1

        self.prenet = Prenet(
            prenet_type,
            prenet_dropout,
            [self.prenet_dim, self.prenet_dim],
            bias=False,
            name='prenet',
        )
        self.attention_rnn = keras.layers.LSTMCell(
            self.query_dim, use_bias=True, name=f'{self.name}/attention_rnn')
        self.attention_rnn_dropout = keras.layers.Dropout(0.5)

        # TODO: implement other attn options
        self.attention = Attention(
            attn_dim=self.attn_dim,
            use_loc_attn=True,
            loc_attn_n_filters=32,
            loc_attn_kernel_size=31,
            use_windowing=False,
            norm=attn_norm,
            use_forward_attn=use_forward_attn,
            use_trans_agent=use_trans_agent,
            use_forward_attn_mask=use_forward_attn_mask,
            name='attention',
        )
        self.decoder_rnn = keras.layers.LSTMCell(
            self.decoder_rnn_dim, use_bias=True, name=f'{self.name}/decoder_rnn')
        self.decoder_rnn_dropout = keras.layers.Dropout(0.5)
        self.linear_projection = keras.layers.Dense(
            self.frame_dim * r, name=f'{self.name}/linear_projection/linear_layer')
        self.stopnet = keras.layers.Dense(1, name=f'{self.name}/stopnet/linear_layer')


    def set_max_decoder_steps(self, new_max_steps):
        """Replace the iteration cap used by ``decode_inference``."""
        step_limit = tf.constant(new_max_steps, dtype=tf.int32)
        self.max_decoder_steps = step_limit

    def set_r(self, new_r):
        """Replace the number of frames per decoder step (``self.r``).

        ``self.r_init`` and the width of the linear projection are left as
        they were set at construction time.
        """
        reduction = tf.constant(new_r, dtype=tf.int32)
        self.r = reduction

    def build_decoder_initial_states(self, batch_size, memory_dim, memory_length):
        """Build the state tuple consumed by ``step`` for the first iteration.

        Returns:
            A 5-tuple ``(frame, context, attention_rnn_state,
            decoder_rnn_state, attention_states)`` where ``frame`` is an
            all-zero ``[batch_size, frame_dim]`` tensor and ``context`` an
            all-zero ``[batch_size, memory_dim]`` tensor.
        """
        initial_frame = tf.zeros([batch_size, self.frame_dim])
        initial_context = tf.zeros([batch_size, memory_dim])
        attn_cell_state = self.attention_rnn.get_initial_state(
            batch_size=batch_size, dtype=tf.float32)
        dec_cell_state = self.decoder_rnn.get_initial_state(
            batch_size=batch_size, dtype=tf.float32)
        attn_module_states = self.attention.init_states(batch_size, memory_length)
        return (initial_frame, initial_context, attn_cell_state,
                dec_cell_state, attn_module_states)

    def step(self, prenet_next, states,
             memory_seq_length=None, training=None):
        """Run one decoder iteration.

        Args:
            prenet_next: prenet output for the current step.
            states: 5-tuple as produced by ``build_decoder_initial_states`` or
                by a previous call to this method; its first element is ignored.
            memory_seq_length: accepted but unused.
            training: forwarded to every sub-layer called here.

        Returns:
            ``(output_frame, stopnet_output, states, attention)`` where the new
            ``states`` carries the last ``frame_dim`` values of ``output_frame``
            as its first element.
        """
        _, prev_context, attn_cell_state, dec_cell_state, attn_module_states = states

        # attention RNN consumes the prenet output and the previous context
        attn_cell_input = tf.concat([prenet_next, prev_context], -1)
        query, attn_cell_state = self.attention_rnn(
            attn_cell_input, attn_cell_state, training=training)
        query = self.attention_rnn_dropout(query, training=training)

        context, alignment, attn_module_states = self.attention(
            query, attn_module_states, training=training)

        # decoder RNN consumes the query and the fresh context
        dec_cell_input = tf.concat([query, context], -1)
        dec_hidden, dec_cell_state = self.decoder_rnn(
            dec_cell_input, dec_cell_state, training=training)
        dec_hidden = self.decoder_rnn_dropout(dec_hidden, training=training)

        projected = self.linear_projection(
            tf.concat([dec_hidden, context], -1), training=training)
        # the stopnet sees the full projection, before trimming to r frames
        stop_logit = self.stopnet(
            tf.concat([dec_hidden, projected], -1), training=training)

        projected = projected[:, :self.r * self.frame_dim]
        last_frame = projected[:, self.frame_dim * (self.r - 1):]
        next_states = (last_frame, context, attn_cell_state,
                       dec_cell_state, attn_module_states)
        return projected, stop_logit, next_states, alignment

    def decode(self, memory, states, frames, memory_seq_length=None):
        """Teacher-forced decoding over the ground-truth ``frames``.

        Args:
            memory: encoder outputs; a rank-3 tensor whose first dimension is
                the batch size.
            states: initial state tuple (see ``build_decoder_initial_states``);
                ``states[0]`` is prepended to ``frames`` as the first input.
            frames: ground-truth frames with time on axis 1.
            memory_seq_length: forwarded to ``step`` (unused there).

        Returns:
            ``(outputs, stop_tokens, attentions)``: ``outputs`` reshaped to
            ``[B, -1, frame_dim]``, ``stop_tokens`` with one value per decoder
            step, and ``attentions`` stacked batch-first.
        """
        B, _, _ = shape_list(memory)
        # each decoder step emits r frames, so T // r steps cover the target
        num_iter = shape_list(frames)[1] // self.r
        # init states
        frame_zero = tf.expand_dims(states[0], 1)
        frames = tf.concat([frame_zero, frames], axis=1)
        outputs = tf.TensorArray(dtype=tf.float32, size=num_iter)
        attentions = tf.TensorArray(dtype=tf.float32, size=num_iter)
        stop_tokens = tf.TensorArray(dtype=tf.float32, size=num_iter)
        # pre-computes
        self.attention.process_values(memory)
        # the prenet always runs in training mode on this path
        prenet_output = self.prenet(frames, training=True)
        step_count = tf.constant(0, dtype=tf.int32)

        def _body(step, memory, prenet_output, states, outputs, stop_tokens, attentions):
            # After prepending the initial frame, index ``step * r`` holds the
            # last ground-truth frame of the previous group of r frames, which
            # is the frame ``step`` itself carries forward in ``states[0]``.
            # Indexing with plain ``step`` fed the wrong frames whenever r > 1;
            # for r == 1 both expressions are identical.
            prenet_next = prenet_output[:, step * self.r]
            # NOTE(review): no explicit ``training`` flag is forwarded to step().
            output, stop_token, states, attention = self.step(prenet_next,
                                                              states,
                                                              memory_seq_length)
            outputs = outputs.write(step, output)
            attentions = attentions.write(step, attention)
            stop_tokens = stop_tokens.write(step, stop_token)
            return step + 1, memory, prenet_output, states, outputs, stop_tokens, attentions

        # The condition is always true; the loop length is bounded solely by
        # ``maximum_iterations``.
        _, memory, _, states, outputs, stop_tokens, attentions = \
                tf.while_loop(lambda *arg: True,
                              _body,
                              loop_vars=(step_count, memory, prenet_output,
                                         states, outputs, stop_tokens, attentions),
                              parallel_iterations=32,
                              swap_memory=True,
                              maximum_iterations=num_iter)

        # stack() yields step-major tensors; move the batch axis to the front
        outputs = outputs.stack()
        attentions = attentions.stack()
        stop_tokens = stop_tokens.stack()
        outputs = tf.transpose(outputs, [1, 0, 2])
        attentions = tf.transpose(attentions, [1, 0, 2])
        stop_tokens = tf.transpose(stop_tokens, [1, 0, 2])
        stop_tokens = tf.squeeze(stop_tokens, axis=2)
        # unfold the r frames packed into each step's output
        outputs = tf.reshape(outputs, [B, -1, self.frame_dim])
        return outputs, stop_tokens, attentions

    def decode_inference(self, memory, states):
        """Autoregressive decoding driven by the decoder's own outputs.

        The loop feeds ``states[0]`` through the prenet at every step and ends
        once the sigmoid of the stopnet output exceeds ``self.stop_thresh`` for
        every element, or after ``self.max_decoder_steps`` iterations.

        Returns:
            ``(outputs, stop_tokens, attentions)``: ``outputs`` reshaped to
            ``[B, -1, frame_dim]``, ``stop_tokens`` holding the sigmoid
            activations per step, and ``attentions`` stacked batch-first.
        """
        batch_size, _, _ = shape_list(memory)

        def _growing_array():
            # length is unknown up front, so the arrays grow on write
            return tf.TensorArray(dtype=tf.float32, size=0,
                                  clear_after_read=False, dynamic_size=True)

        # init states
        frame_buffer = _growing_array()
        alignment_buffer = _growing_array()
        stop_buffer = _growing_array()
        # pre-computes
        self.attention.process_values(memory)

        # iter vars
        finished = tf.constant(False, dtype=tf.bool)
        first_step = tf.constant(0, dtype=tf.int32)

        def _not_finished(step, memory, states, outputs, stop_tokens, attentions, stop_flag):
            return tf.logical_not(stop_flag)

        def _advance(step, memory, states, outputs, stop_tokens, attentions, stop_flag):
            prenet_next = self.prenet(states[0], training=False)
            output, stop_logit, states, attention = self.step(
                prenet_next, states, None, training=False)
            stop_prob = tf.math.sigmoid(stop_logit)
            outputs = outputs.write(step, output)
            attentions = attentions.write(step, attention)
            stop_tokens = stop_tokens.write(step, stop_prob)
            # stop only when every element is above the threshold
            stop_flag = tf.reduce_all(tf.greater(stop_prob, self.stop_thresh))
            return step + 1, memory, states, outputs, stop_tokens, attentions, stop_flag

        loop_result = tf.while_loop(
            _not_finished,
            _advance,
            loop_vars=(first_step, memory, states, frame_buffer,
                       stop_buffer, alignment_buffer, finished),
            parallel_iterations=32,
            swap_memory=True,
            maximum_iterations=self.max_decoder_steps)
        _, _, _, frame_buffer, stop_buffer, alignment_buffer, _ = loop_result

        # stack() yields step-major tensors; move the batch axis to the front
        step_to_batch = [1, 0, 2]
        outputs = tf.transpose(frame_buffer.stack(), step_to_batch)
        attentions = tf.transpose(alignment_buffer.stack(), step_to_batch)
        stop_tokens = tf.transpose(stop_buffer.stack(), step_to_batch)
        stop_tokens = tf.squeeze(stop_tokens, axis=2)
        outputs = tf.reshape(outputs, [batch_size, -1, self.frame_dim])
        return outputs, stop_tokens, attentions

    def call(self, memory, states, frames=None, memory_seq_length=None, training=False):
        """Dispatch to teacher-forced or autoregressive decoding.

        With a truthy ``training`` the ground-truth ``frames`` drive ``decode``;
        otherwise ``decode_inference`` runs and ``frames`` and
        ``memory_seq_length`` are ignored.
        """
        if not training:
            return self.decode_inference(memory, states)
        return self.decode(memory, states, frames, memory_seq_length)