import tensorflow as tf
from tensorflow import keras

from TTS.tf.layers.tacotron2 import Encoder, Decoder, Postnet
from TTS.tf.utils.tf_utils import shape_list


#pylint: disable=too-many-ancestors
class Tacotron2(keras.models.Model):
    """Tacotron2 model wired together from an embedding, Encoder, Decoder and Postnet.

    The forward pass embeds the input character ids, runs the encoder, builds
    the decoder's initial states, runs the decoder and finally adds the
    postnet output to the decoder frames.

    Args:
        num_chars: size of the character vocabulary (input dim of the
            embedding layer).
        num_speakers: stored on the instance; not otherwise used in this class.
        r: stored on the instance and passed to the Decoder as its second
            positional argument (presumably the reduction factor -- confirm
            against ``Decoder``).
        postnet_output_dim: first argument given to ``Postnet``.
        decoder_output_dim: first argument given to ``Decoder``.
        attn_type: forwarded to the Decoder as ``attn_type``.
        attn_win: forwarded to the Decoder as ``use_attn_win``.
        attn_norm: forwarded to the Decoder as ``attn_norm``.
        attn_K: forwarded to the Decoder as ``attn_K``.
        prenet_type: forwarded to the Decoder as ``prenet_type``.
        prenet_dropout: forwarded to the Decoder as ``prenet_dropout``.
        forward_attn: forwarded to the Decoder as ``use_forward_attn``.
        trans_agent: forwarded to the Decoder as ``use_trans_agent``.
        forward_attn_mask: forwarded to the Decoder as
            ``use_forward_attn_mask``.
        location_attn: forwarded to the Decoder as ``use_location_attn``.
        separate_stopnet: forwarded to the Decoder as ``separate_stopnet``.
        bidirectional_decoder: stored on the instance; not otherwise used in
            this class.
        enable_tflite: stored on the instance and forwarded to the Decoder.
    """

    def __init__(self,
                 num_chars,
                 num_speakers,
                 r,
                 postnet_output_dim=80,
                 decoder_output_dim=80,
                 attn_type='original',
                 attn_win=False,
                 attn_norm="softmax",
                 attn_K=4,
                 prenet_type="original",
                 prenet_dropout=True,
                 forward_attn=False,
                 trans_agent=False,
                 forward_attn_mask=False,
                 location_attn=True,
                 separate_stopnet=True,
                 bidirectional_decoder=False,
                 enable_tflite=False):
        super().__init__()
        self.r = r
        self.decoder_output_dim = decoder_output_dim
        self.postnet_output_dim = postnet_output_dim
        self.bidirectional_decoder = bidirectional_decoder
        self.num_speakers = num_speakers
        # Fixed speaker embedding size handed to the Decoder below.
        self.speaker_embed_dim = 256
        self.enable_tflite = enable_tflite

        self.embedding = keras.layers.Embedding(num_chars, 512, name='embedding')
        self.encoder = Encoder(512, name='encoder')
        # TODO: most of the decoder args have no use at the moment
        self.decoder = Decoder(decoder_output_dim,
                               r,
                               attn_type=attn_type,
                               use_attn_win=attn_win,
                               attn_norm=attn_norm,
                               prenet_type=prenet_type,
                               prenet_dropout=prenet_dropout,
                               use_forward_attn=forward_attn,
                               use_trans_agent=trans_agent,
                               use_forward_attn_mask=forward_attn_mask,
                               use_location_attn=location_attn,
                               attn_K=attn_K,
                               separate_stopnet=separate_stopnet,
                               speaker_emb_dim=self.speaker_embed_dim,
                               name='decoder',
                               enable_tflite=enable_tflite)
        # NOTE(review): the literal 5 is presumably the number of postnet
        # layers -- confirm against ``Postnet``.
        self.postnet = Postnet(postnet_output_dim, 5, name='postnet')

    @tf.function(experimental_relax_shapes=True)
    def call(self, characters, text_lengths=None, frames=None, training=None):
        """Dispatch to the training or the inference forward pass.

        Args:
            characters: character id tensor; its shape is unpacked as (B, T).
            text_lengths: only used on the training path.
            frames: only used on the training path.
            training: a truthy value selects ``training``; any falsy value,
                including the default ``None``, selects ``inference``.

        Returns:
            Tuple ``(decoder_frames, output_frames, attentions, stop_tokens)``.
        """
        if training:
            return self.training(characters, text_lengths, frames)
        # Every falsy value (False as well as None) means inference.
        return self.inference(characters)

    def training(self, characters, text_lengths, frames):
        """Run the forward pass with every sub-layer in training mode.

        Args:
            characters: character id tensor; its shape is unpacked as (B, T).
            text_lengths: passed to the decoder after ``frames``.
            frames: passed to the decoder after the initial states
                (presumably the teacher-forcing targets -- confirm against
                ``Decoder``).

        Returns:
            Tuple ``(decoder_frames, output_frames, attentions, stop_tokens)``
            where ``output_frames`` is ``decoder_frames`` plus the postnet
            output.
        """
        B, T = shape_list(characters)
        embedding_vectors = self.embedding(characters, training=True)
        encoder_output = self.encoder(embedding_vectors, training=True)
        # NOTE(review): 512 presumably has to match the encoder width used in
        # __init__ -- confirm against ``build_decoder_initial_states``.
        decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
        decoder_frames, stop_tokens, attentions = self.decoder(
            encoder_output, decoder_states, frames, text_lengths, training=True)
        postnet_frames = self.postnet(decoder_frames, training=True)
        output_frames = decoder_frames + postnet_frames
        return decoder_frames, output_frames, attentions, stop_tokens

    def inference(self, characters):
        """Run the forward pass with every sub-layer in inference mode.

        Args:
            characters: character id tensor; its shape is unpacked as (B, T).

        Returns:
            Tuple ``(decoder_frames, output_frames, attentions, stop_tokens)``
            where ``output_frames`` is ``decoder_frames`` plus the postnet
            output.
        """
        B, T = shape_list(characters)
        embedding_vectors = self.embedding(characters, training=False)
        encoder_output = self.encoder(embedding_vectors, training=False)
        # NOTE(review): 512 presumably has to match the encoder width used in
        # __init__ -- confirm against ``build_decoder_initial_states``.
        decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
        decoder_frames, stop_tokens, attentions = self.decoder(
            encoder_output, decoder_states, training=False)
        postnet_frames = self.postnet(decoder_frames, training=False)
        output_frames = decoder_frames + postnet_frames
        return decoder_frames, output_frames, attentions, stop_tokens

    @tf.function(
        experimental_relax_shapes=True,
        input_signature=[
            tf.TensorSpec([1, None], dtype=tf.int32),
        ],)
    def inference_tflite(self, characters):
        """Inference entry point traced with a fixed input signature.

        The signature restricts ``characters`` to an int32 tensor whose first
        dimension is 1 and whose second dimension is free. The computation is
        the same as ``inference``.

        Returns:
            Tuple ``(decoder_frames, output_frames, attentions, stop_tokens)``.
        """
        return self.inference(characters)

    def build_inference(self):
        """Run one inference pass on random character ids of shape [1, 4].

        The ids are drawn uniformly from ``[0, 10)``. The outputs are
        discarded (presumably the call exists only to create the model
        variables -- confirm with callers).
        """
        # TODO: issue https://github.com/PyCQA/pylint/issues/3613
        input_ids = tf.random.uniform(shape=[1, 4], maxval=10, dtype=tf.int32)  #pylint: disable=unexpected-keyword-arg
        # Request the inference path explicitly instead of relying on the
        # default ``training=None`` being treated as falsy.
        self(input_ids, training=False)