import os
import re
import glob
import shutil
import datetime
import json
import subprocess
import importlib
import pickle
import numpy as np
from collections import OrderedDict, Counter
import tensorflow as tf


def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs):
    checkpoint_path = 'tts_tf_checkpoint_{}.pkl'.format(current_step)
    checkpoint_path = os.path.join(output_folder, checkpoint_path)
    state = {
        'model': model.weights,
        'optimizer': optimizer,
        'step': current_step,
        'epoch': epoch,
        'date': datetime.date.today().strftime("%B %d, %Y"),
        'r': r
    }
    state.update(kwargs)
    pickle.dump(state, open(checkpoint_path, 'wb'))


def load_checkpoint(model, checkpoint_path):
    checkpoint = pickle.load(open(checkpoint_path, 'rb'))
    chkp_var_dict = dict([(var.name, var.numpy()) for var in checkpoint['model']])
    tf_vars = model.weights
    for tf_var in tf_vars:
        layer_name = tf_var.name
        chkp_var_value = chkp_var_dict[layer_name]
        tf.keras.backend.set_value(tf_var, chkp_var_value)
    if 'r' in checkpoint.keys():
        model.decoder.set_r(checkpoint['r'])
    return model


def sequence_mask(sequence_length, max_len=None):
    if max_len is None:
        max_len = sequence_length.max()
    batch_size = sequence_length.size(0)
    seq_range = np.empty([0, max_len], dtype=np.int8)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    if sequence_length.is_cuda:
        seq_range_expand = seq_range_expand.cuda()
    seq_length_expand = (
        sequence_length.unsqueeze(1).expand_as(seq_range_expand))
    # B x T_max
    return seq_range_expand < seq_length_expand


# @tf.custom_gradient
def check_gradient(x, grad_clip):
    x_normed = tf.clip_by_norm(x, grad_clip)
    grad_norm = tf.norm(grad_clip)
    return x_normed, grad_norm


def count_parameters(model, c):
    try:
        return model.count_params()
    except:
        input_dummy = tf.convert_to_tensor(np.random.rand(8, 128).astype('int32'))
        input_lengths = np.random.randint(100, 129, (8, ))
        input_lengths[-1] = 128
        input_lengths = tf.convert_to_tensor(input_lengths.astype('int32'))
        mel_spec = np.random.rand(8, 2 * c.r,
                                  c.audio['num_mels']).astype('float32')
        mel_spec = tf.convert_to_tensor(mel_spec)
        speaker_ids = np.random.randint(
            0, 5, (8, )) if c.use_speaker_embedding else None
        _ = model(input_dummy, input_lengths, mel_spec)
        return model.count_params()


def setup_model(num_chars, num_speakers, c):
    print(" > Using model: {}".format(c.model))
    MyModel = importlib.import_module('TTS.tf.models.' + c.model.lower())
    MyModel = getattr(MyModel, c.model)
    if c.model.lower() in "tacotron":
        raise NotImplemented(' [!] Tacotron model is not ready.')
    elif c.model.lower() == "tacotron2":
        model = MyModel(num_chars=num_chars,
                             num_speakers=num_speakers,
                             r=c.r,
                             postnet_output_dim=c.audio['num_mels'],
                             decoder_output_dim=c.audio['num_mels'],
                             attn_type=c.attention_type,
                             attn_win=c.windowing,
                             attn_norm=c.attention_norm,
                             prenet_type=c.prenet_type,
                             prenet_dropout=c.prenet_dropout,
                             forward_attn=c.use_forward_attn,
                             trans_agent=c.transition_agent,
                             forward_attn_mask=c.forward_attn_mask,
                             location_attn=c.location_attn,
                             attn_K=c.attention_heads,
                             separate_stopnet=c.separate_stopnet,
                             bidirectional_decoder=c.bidirectional_decoder)
    return model