import importlib
import re
from collections import Counter

import numpy as np
import torch

from TTS.utils.generic_utils import check_argument


def split_dataset(items):
    """Split ``items`` into an evaluation subset and a training subset.

    The evaluation size is 1% of the samples, capped at 500. ``items`` is
    shuffled in place after seeding NumPy's global RNG with 0. The last element
    of each item is used as its speaker label.

    With a single speaker, the first ``eval_split_size`` shuffled items form the
    evaluation set. With several speakers, evaluation items are drawn at random
    and removed from ``items``, but never the last remaining sample of a
    speaker, so every speaker keeps at least one training sample.

    Args:
        items (list): samples whose last element is the speaker label. The list
            is shuffled in place and, in the multi-speaker case, also shrunk in
            place.

    Returns:
        tuple: ``(items_eval, items_train)``.

    Raises:
        AssertionError: if there are fewer than 100 samples.
        ValueError: if, in the multi-speaker case, too few samples can be moved
            to the evaluation set without leaving a speaker with no training
            sample.
    """
    speaker_counter = Counter(item[-1] for item in items)
    is_multi_speaker = len(speaker_counter) > 1
    eval_split_size = min(500, int(len(items) * 0.01))
    assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples."

    if is_multi_speaker:
        # Each speaker must keep one sample, so only the surplus can be moved.
        # Without this check the sampling loop below would never terminate.
        removable = len(items) - len(speaker_counter)
        if removable < eval_split_size:
            raise ValueError(
                " [!] Cannot build an evaluation set of {} samples: only {} samples can be removed "
                "while keeping every speaker in the training set.".format(eval_split_size, removable)
            )

    np.random.seed(0)
    np.random.shuffle(items)
    if not is_multi_speaker:
        return items[:eval_split_size], items[eval_split_size:]

    items_eval = []
    while len(items_eval) < eval_split_size:
        item_idx = np.random.randint(0, len(items))
        speaker = items[item_idx][-1]
        # Skip the draw if it would remove the speaker's last training sample.
        if speaker_counter[speaker] > 1:
            items_eval.append(items[item_idx])
            speaker_counter[speaker] -= 1
            del items[item_idx]
    return items_eval, items


# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
def sequence_mask(sequence_length, max_len=None):
    """Build a boolean mask that is True for positions before each sequence's length.

    Args:
        sequence_length (torch.Tensor): per-sequence lengths.
        max_len (int, optional): number of mask columns. When omitted, the
            largest value in ``sequence_length`` is used.

    Returns:
        torch.Tensor: mask of shape ``B x T_max`` where entry ``[b, t]`` is
        ``t < sequence_length[b]``.
    """
    longest = sequence_length.data.max() if max_len is None else max_len
    positions = torch.arange(longest, dtype=sequence_length.dtype, device=sequence_length.device)
    # (1 x T_max) compared against (B x 1) broadcasts to B x T_max
    return positions[None, :] < sequence_length[:, None]


def to_camel(text):
    """Convert a snake_case name to CamelCase, spelling "Tts" as "TTS".

    The input is first passed through ``str.capitalize`` (first character
    upper-cased, the rest lower-cased). Every underscore that is not at the
    start of the string and is followed by an ASCII letter is then dropped and
    that letter upper-cased.

    Example: ``"glow_tts"`` becomes ``"GlowTTS"``.
    """

    def _upper_after_underscore(match):
        return match.group(1).upper()

    camel = re.sub(r"(?!^)_([a-zA-Z])", _upper_after_underscore, text.capitalize())
    return camel.replace("Tts", "TTS")


def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
    """Instantiate the TTS model named by ``c.model``.

    The module ``TTS.tts.models.<c.model lower-cased>`` is imported and the
    class whose name is ``to_camel(c.model)`` is looked up in it. The class is
    then called with the constructor arguments of the matching model family.

    Args:
        num_chars (int): number of input symbols. ``c.add_blank`` (treated as
            False when absent) is added to it.
        num_speakers (int): forwarded as ``num_speakers`` to the Tacotron and
            GlowTTS constructors; not used for the other models.
        c: config object read both by attribute (``c.model``, ``c.audio``) and
            by key (``c["hidden_channels"]``).
        speaker_embedding_dim (int, optional): forwarded to the Tacotron
            constructors as ``speaker_embedding_dim`` and to GlowTTS as
            ``external_speaker_embedding_dim``.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: if the model module and class exist but this function has
            no construction branch for ``c.model``.
    """
    print(" > Using model: {}".format(c.model))
    model_name = c.model.lower()
    MyModel = importlib.import_module("TTS.tts.models." + model_name)
    MyModel = getattr(MyModel, to_camel(c.model))
    total_chars = num_chars + getattr(c, "add_blank", False)
    # Exact comparison: a substring test would also accept "" or "taco".
    if model_name == "tacotron":
        model = MyModel(
            num_chars=total_chars,
            num_speakers=num_speakers,
            r=c.r,
            postnet_output_dim=int(c.audio["fft_size"] / 2 + 1),
            decoder_output_dim=c.audio["num_mels"],
            gst=c.use_gst,
            gst_embedding_dim=c.gst["gst_embedding_dim"],
            gst_num_heads=c.gst["gst_num_heads"],
            gst_style_tokens=c.gst["gst_style_tokens"],
            gst_use_speaker_embedding=c.gst["gst_use_speaker_embedding"],
            memory_size=c.memory_size,
            attn_type=c.attention_type,
            attn_win=c.windowing,
            attn_norm=c.attention_norm,
            prenet_type=c.prenet_type,
            prenet_dropout=c.prenet_dropout,
            prenet_dropout_at_inference=c.prenet_dropout_at_inference if "prenet_dropout_at_inference" in c else False,
            forward_attn=c.use_forward_attn,
            trans_agent=c.transition_agent,
            forward_attn_mask=c.forward_attn_mask,
            location_attn=c.location_attn,
            attn_K=c.attention_heads,
            separate_stopnet=c.separate_stopnet,
            bidirectional_decoder=c.bidirectional_decoder,
            double_decoder_consistency=c.double_decoder_consistency,
            ddc_r=c.ddc_r,
            speaker_embedding_dim=speaker_embedding_dim,
        )
    elif model_name == "tacotron2":
        # Unlike Tacotron, the postnet output has num_mels channels and no
        # memory_size is passed.
        model = MyModel(
            num_chars=total_chars,
            num_speakers=num_speakers,
            r=c.r,
            postnet_output_dim=c.audio["num_mels"],
            decoder_output_dim=c.audio["num_mels"],
            gst=c.use_gst,
            gst_embedding_dim=c.gst["gst_embedding_dim"],
            gst_num_heads=c.gst["gst_num_heads"],
            gst_style_tokens=c.gst["gst_style_tokens"],
            gst_use_speaker_embedding=c.gst["gst_use_speaker_embedding"],
            attn_type=c.attention_type,
            attn_win=c.windowing,
            attn_norm=c.attention_norm,
            prenet_type=c.prenet_type,
            prenet_dropout=c.prenet_dropout,
            prenet_dropout_at_inference=c.prenet_dropout_at_inference if "prenet_dropout_at_inference" in c else False,
            forward_attn=c.use_forward_attn,
            trans_agent=c.transition_agent,
            forward_attn_mask=c.forward_attn_mask,
            location_attn=c.location_attn,
            attn_K=c.attention_heads,
            separate_stopnet=c.separate_stopnet,
            bidirectional_decoder=c.bidirectional_decoder,
            double_decoder_consistency=c.double_decoder_consistency,
            ddc_r=c.ddc_r,
            speaker_embedding_dim=speaker_embedding_dim,
        )
    elif model_name == "glow_tts":
        # Decoder/flow hyper-parameters are fixed here rather than read from the config.
        model = MyModel(
            num_chars=total_chars,
            hidden_channels_enc=c["hidden_channels_encoder"],
            hidden_channels_dec=c["hidden_channels_decoder"],
            hidden_channels_dp=c["hidden_channels_duration_predictor"],
            out_channels=c.audio["num_mels"],
            encoder_type=c.encoder_type,
            encoder_params=c.encoder_params,
            use_encoder_prenet=c["use_encoder_prenet"],
            num_flow_blocks_dec=12,
            kernel_size_dec=5,
            dilation_rate=1,
            num_block_layers=4,
            dropout_p_dec=0.05,
            num_speakers=num_speakers,
            c_in_channels=0,
            num_splits=4,
            num_squeeze=2,
            sigmoid_scale=False,
            mean_only=True,
            external_speaker_embedding_dim=speaker_embedding_dim,
        )
    elif model_name == "speedy_speech":
        model = MyModel(
            num_chars=total_chars,
            out_channels=c.audio["num_mels"],
            hidden_channels=c["hidden_channels"],
            positional_encoding=c["positional_encoding"],
            encoder_type=c["encoder_type"],
            encoder_params=c["encoder_params"],
            decoder_type=c["decoder_type"],
            decoder_params=c["decoder_params"],
            c_in_channels=0,
        )
    elif model_name == "align_tts":
        model = MyModel(
            num_chars=total_chars,
            out_channels=c.audio["num_mels"],
            hidden_channels=c["hidden_channels"],
            hidden_channels_dp=c["hidden_channels_dp"],
            encoder_type=c["encoder_type"],
            encoder_params=c["encoder_params"],
            decoder_type=c["decoder_type"],
            decoder_params=c["decoder_params"],
            c_in_channels=0,
        )
    else:
        # Previously this fell through to `return model` and raised UnboundLocalError.
        raise ValueError(" [!] Unsupported model: {}".format(c.model))
    return model


def is_tacotron(c):
    """Return True if the lower-cased ``c["model"]`` contains "tacotron"."""
    model_name = c["model"].lower()
    return "tacotron" in model_name


def check_config_tts(c):
    """Validate a TTS training config, one ``check_argument`` call per field.

    Checks are grouped by config section: audio, vocabulary, normalization,
    training, loss, validation, optimizer, model parameters, logging, data
    loading, paths, multi-speaker/GST and datasets. Fields that only matter for
    some model families are required only for those families.

    Args:
        c: dict-like config. It is indexed with ``c[key]`` and queried with
            ``c.keys()``.

    Returns:
        None. A failed check surfaces as whatever ``check_argument`` raises.
    """
    check_argument(
        "model",
        c,
        enum_list=["tacotron", "tacotron2", "glow_tts", "speedy_speech", "align_tts"],
        restricted=True,
        val_type=str,
    )
    check_argument("run_name", c, restricted=True, val_type=str)
    check_argument("run_description", c, val_type=str)

    # AUDIO
    check_argument("audio", c, restricted=True, val_type=dict)

    # audio processing parameters
    check_argument("num_mels", c["audio"], restricted=True, val_type=int, min_val=10, max_val=2056)
    check_argument("fft_size", c["audio"], restricted=True, val_type=int, min_val=128, max_val=4058)
    check_argument("sample_rate", c["audio"], restricted=True, val_type=int, min_val=512, max_val=100000)
    check_argument(
        "frame_length_ms",
        c["audio"],
        restricted=True,
        val_type=float,
        min_val=10,
        max_val=1000,
        alternative="win_length",
    )
    check_argument(
        "frame_shift_ms", c["audio"], restricted=True, val_type=float, min_val=1, max_val=1000, alternative="hop_length"
    )
    check_argument("preemphasis", c["audio"], restricted=True, val_type=float, min_val=0, max_val=1)
    check_argument("min_level_db", c["audio"], restricted=True, val_type=int, min_val=-1000, max_val=10)
    check_argument("ref_level_db", c["audio"], restricted=True, val_type=int, min_val=0, max_val=1000)
    check_argument("power", c["audio"], restricted=True, val_type=float, min_val=1, max_val=5)
    check_argument("griffin_lim_iters", c["audio"], restricted=True, val_type=int, min_val=10, max_val=1000)

    # vocabulary parameters
    # The "characters" section is optional; its fields are required only when
    # the section itself is present.
    check_argument("characters", c, restricted=False, val_type=dict)
    has_characters = "characters" in c.keys()
    characters = c["characters"] if has_characters else {}
    check_argument("pad", characters, restricted=has_characters, val_type=str)
    check_argument("eos", characters, restricted=has_characters, val_type=str)
    check_argument("bos", characters, restricted=has_characters, val_type=str)
    check_argument("characters", characters, restricted=has_characters, val_type=str)
    check_argument("phonemes", characters, restricted=has_characters and c["use_phonemes"], val_type=str)
    check_argument("punctuations", characters, restricted=has_characters, val_type=str)

    # normalization parameters
    check_argument("signal_norm", c["audio"], restricted=True, val_type=bool)
    check_argument("symmetric_norm", c["audio"], restricted=True, val_type=bool)
    check_argument("max_norm", c["audio"], restricted=True, val_type=float, min_val=0.1, max_val=1000)
    check_argument("clip_norm", c["audio"], restricted=True, val_type=bool)
    check_argument("mel_fmin", c["audio"], restricted=True, val_type=float, min_val=0.0, max_val=1000)
    check_argument("mel_fmax", c["audio"], restricted=True, val_type=float, min_val=500.0)
    check_argument("spec_gain", c["audio"], restricted=True, val_type=[int, float], min_val=1, max_val=100)
    check_argument("do_trim_silence", c["audio"], restricted=True, val_type=bool)
    check_argument("trim_db", c["audio"], restricted=True, val_type=int)

    # training parameters
    check_argument("batch_size", c, restricted=True, val_type=int, min_val=1)
    check_argument("eval_batch_size", c, restricted=True, val_type=int, min_val=1)
    check_argument("r", c, restricted=True, val_type=int, min_val=1)
    check_argument("gradual_training", c, restricted=False, val_type=list)
    check_argument("mixed_precision", c, restricted=False, val_type=bool)
    # check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100)

    # loss parameters
    model_name = c["model"].lower()
    check_argument("loss_masking", c, restricted=True, val_type=bool)
    if model_name in ["tacotron", "tacotron2"]:
        check_argument("decoder_loss_alpha", c, restricted=True, val_type=float, min_val=0)
        check_argument("postnet_loss_alpha", c, restricted=True, val_type=float, min_val=0)
        check_argument("postnet_diff_spec_alpha", c, restricted=True, val_type=float, min_val=0)
        check_argument("decoder_diff_spec_alpha", c, restricted=True, val_type=float, min_val=0)
        check_argument("decoder_ssim_alpha", c, restricted=True, val_type=float, min_val=0)
        check_argument("postnet_ssim_alpha", c, restricted=True, val_type=float, min_val=0)
        check_argument("ga_alpha", c, restricted=True, val_type=float, min_val=0)
    if model_name in ["speedy_speech", "align_tts"]:
        # This branch used to be unreachable: the condition compared the bound
        # method `c["model"].lower` (no call) against strings. Because these
        # checks never ran, they are enabled as optional numeric checks rather
        # than hard requirements, so configs that work today are not rejected.
        # NOTE(review): tighten to restricted=True once the shipped
        # speedy_speech / align_tts configs are confirmed to define all three.
        check_argument("ssim_alpha", c, restricted=False, val_type=[int, float], min_val=0)
        check_argument("l1_alpha", c, restricted=False, val_type=[int, float], min_val=0)
        check_argument("huber_alpha", c, restricted=False, val_type=[int, float], min_val=0)

    # validation parameters
    check_argument("run_eval", c, restricted=True, val_type=bool)
    check_argument("test_delay_epochs", c, restricted=True, val_type=int, min_val=0)
    check_argument("test_sentences_file", c, restricted=False, val_type=str)

    # optimizer
    tacotron_family = is_tacotron(c)
    check_argument("noam_schedule", c, restricted=False, val_type=bool)
    check_argument("grad_clip", c, restricted=True, val_type=float, min_val=0.0)
    check_argument("epochs", c, restricted=True, val_type=int, min_val=1)
    check_argument("lr", c, restricted=True, val_type=float, min_val=0)
    check_argument("wd", c, restricted=tacotron_family, val_type=float, min_val=0)
    check_argument("warmup_steps", c, restricted=True, val_type=int, min_val=0)
    check_argument("seq_len_norm", c, restricted=tacotron_family, val_type=bool)

    # tacotron prenet
    check_argument("memory_size", c, restricted=tacotron_family, val_type=int, min_val=-1)
    check_argument("prenet_type", c, restricted=tacotron_family, val_type=str, enum_list=["original", "bn"])
    check_argument("prenet_dropout", c, restricted=tacotron_family, val_type=bool)

    # attention
    check_argument(
        "attention_type",
        c,
        restricted=tacotron_family,
        val_type=str,
        enum_list=["graves", "original", "dynamic_convolution"],
    )
    check_argument("attention_heads", c, restricted=tacotron_family, val_type=int)
    check_argument("attention_norm", c, restricted=tacotron_family, val_type=str, enum_list=["sigmoid", "softmax"])
    check_argument("windowing", c, restricted=tacotron_family, val_type=bool)
    check_argument("use_forward_attn", c, restricted=tacotron_family, val_type=bool)
    check_argument("forward_attn_mask", c, restricted=tacotron_family, val_type=bool)
    check_argument("transition_agent", c, restricted=tacotron_family, val_type=bool)
    check_argument("location_attn", c, restricted=tacotron_family, val_type=bool)
    check_argument("bidirectional_decoder", c, restricted=tacotron_family, val_type=bool)
    check_argument("double_decoder_consistency", c, restricted=tacotron_family, val_type=bool)
    check_argument("ddc_r", c, restricted="double_decoder_consistency" in c.keys(), min_val=1, max_val=7, val_type=int)

    if model_name in ["tacotron", "tacotron2"]:
        # stopnet
        check_argument("stopnet", c, restricted=tacotron_family, val_type=bool)
        check_argument("separate_stopnet", c, restricted=tacotron_family, val_type=bool)

    # Model Parameters for non-tacotron models
    if model_name in ["speedy_speech", "align_tts"]:
        # Also unreachable before the `.lower()` fix. The required keys mirror
        # what setup_model() reads by index for these two models.
        # positional_encoding is read only for speedy_speech.
        # NOTE(review): bool is assumed for positional_encoding (the previous
        # val_type=type would only accept a class object) - confirm.
        check_argument("positional_encoding", c, restricted=model_name == "speedy_speech", val_type=bool)
        check_argument("encoder_type", c, restricted=True, val_type=str)
        check_argument("encoder_params", c, restricted=True, val_type=dict)
        # NOTE(review): str / dict assumed by symmetry with the encoder fields - confirm.
        check_argument("decoder_type", c, restricted=True, val_type=str)
        check_argument("decoder_params", c, restricted=True, val_type=dict)
        # setup_model() does not read this key, so it is only type-checked when present.
        check_argument("decoder_residual_conv_bn_params", c, restricted=False, val_type=dict)

    # GlowTTS parameters
    check_argument("encoder_type", c, restricted=not tacotron_family, val_type=str)

    # tensorboard
    check_argument("print_step", c, restricted=True, val_type=int, min_val=1)
    check_argument("tb_plot_step", c, restricted=True, val_type=int, min_val=1)
    check_argument("save_step", c, restricted=True, val_type=int, min_val=1)
    check_argument("checkpoint", c, restricted=True, val_type=bool)
    check_argument("tb_model_param_stats", c, restricted=True, val_type=bool)

    # dataloading
    # pylint: disable=import-outside-toplevel
    from TTS.tts.utils.text import cleaners

    check_argument("text_cleaner", c, restricted=True, val_type=str, enum_list=dir(cleaners))
    check_argument("enable_eos_bos_chars", c, restricted=True, val_type=bool)
    check_argument("num_loader_workers", c, restricted=True, val_type=int, min_val=0)
    check_argument("num_val_loader_workers", c, restricted=True, val_type=int, min_val=0)
    check_argument("batch_group_size", c, restricted=True, val_type=int, min_val=0)
    check_argument("min_seq_len", c, restricted=True, val_type=int, min_val=0)
    check_argument("max_seq_len", c, restricted=True, val_type=int, min_val=10)
    check_argument("compute_input_seq_cache", c, restricted=True, val_type=bool)

    # paths
    check_argument("output_path", c, restricted=True, val_type=str)

    # multi-speaker and gst
    check_argument("use_speaker_embedding", c, restricted=True, val_type=bool)
    check_argument("use_external_speaker_embedding_file", c, restricted=c["use_speaker_embedding"], val_type=bool)
    # The flag above is optional when speaker embeddings are off, so a missing
    # key is treated as False instead of raising KeyError.
    use_external_file = (
        c["use_external_speaker_embedding_file"] if "use_external_speaker_embedding_file" in c.keys() else False
    )
    check_argument("external_speaker_embedding_file", c, restricted=use_external_file, val_type=str)
    if model_name in ["tacotron", "tacotron2"] and c["use_gst"]:
        check_argument("use_gst", c, restricted=tacotron_family, val_type=bool)
        check_argument("gst", c, restricted=tacotron_family, val_type=dict)
        check_argument("gst_style_input", c["gst"], restricted=tacotron_family, val_type=[str, dict])
        check_argument("gst_embedding_dim", c["gst"], restricted=tacotron_family, val_type=int, min_val=0, max_val=1000)
        check_argument("gst_use_speaker_embedding", c["gst"], restricted=tacotron_family, val_type=bool)
        check_argument("gst_num_heads", c["gst"], restricted=tacotron_family, val_type=int, min_val=2, max_val=10)
        check_argument("gst_style_tokens", c["gst"], restricted=tacotron_family, val_type=int, min_val=1, max_val=1000)

    # datasets - every entry is checked
    check_argument("datasets", c, restricted=True, val_type=list)
    for dataset_entry in c["datasets"]:
        check_argument("name", dataset_entry, restricted=True, val_type=str)
        check_argument("path", dataset_entry, restricted=True, val_type=str)
        check_argument("meta_file_train", dataset_entry, restricted=True, val_type=[str, list])
        check_argument("meta_file_val", dataset_entry, restricted=True, val_type=str)