import os
import re
import sys
import xml.etree.ElementTree as ET
from collections import Counter
from glob import glob
from pathlib import Path
from typing import List

import numpy as np
from tqdm import tqdm

####################
# UTILITIES
####################


def split_dataset(items):
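    """Split `items` into eval and train subsets.

    Keeps at most 500 samples (or 1% of the data) for eval; for multi-speaker
    data, eval samples are drawn so that no speaker disappears from the
    training split.
    """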
    speakers = [item[-1] for item in items]
    is_multi_speaker = len(set(speakers)) > 1
    eval_split_size = min(500, int(len(items) * 0.01))
    assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples."
    np.random.seed(0)
    np.random.shuffle(items)
    if is_multi_speaker:
        items_eval = []
        # the speaker multiset is unchanged by the shuffle, so reuse `speakers`
        speaker_counter = Counter(speakers)
        while len(items_eval) < eval_split_size:
            item_idx = np.random.randint(0, len(items))
            speaker_to_be_removed = items[item_idx][-1]
            if speaker_counter[speaker_to_be_removed] > 1:
                items_eval.append(items[item_idx])
                speaker_counter[speaker_to_be_removed] -= 1
                del items[item_idx]
        return items_eval, items
    return items[:eval_split_size], items[eval_split_size:]


def load_meta_data(datasets, eval_split=True):
    meta_data_train_all = []
    meta_data_eval_all = [] if eval_split else None
    for dataset in datasets:
        name = dataset["name"]
        root_path = dataset["path"]
        meta_file_train = dataset["meta_file_train"]
        meta_file_val = dataset["meta_file_val"]
        # setup the right data processor
        preprocessor = get_preprocessor_by_name(name)
        # load train set
        meta_data_train = preprocessor(root_path, meta_file_train)
        print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
        # load evaluation split if set
        if eval_split:
            if meta_file_val:
                meta_data_eval = preprocessor(root_path, meta_file_val)
            else:
                meta_data_eval, meta_data_train = split_dataset(meta_data_train)
            meta_data_eval_all += meta_data_eval
        meta_data_train_all += meta_data_train
        # load attention masks for duration predictor training
        if dataset.get("meta_file_attn_mask"):
            meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
            for idx, ins in enumerate(meta_data_train_all):
                attn_file = meta_data[ins[1]].strip()
                meta_data_train_all[idx].append(attn_file)
            if meta_data_eval_all:
                for idx, ins in enumerate(meta_data_eval_all):
                    attn_file = meta_data[ins[1]].strip()
                    meta_data_eval_all[idx].append(attn_file)
    return meta_data_train_all, meta_data_eval_all


def load_attention_mask_meta_data(metafile_path):
    """Load meta data file created by compute_attention_masks.py"""
    with open(metafile_path, "r") as f:
        lines = f.readlines()

    meta_data = []
    for line in lines:
        wav_file, attn_file = line.split("|")
        meta_data.append([wav_file, attn_file])
    return meta_data


def get_preprocessor_by_name(name):
    """Returns the respective preprocessing function."""
    thismodule = sys.modules[__name__]
    return getattr(thismodule, name.lower())
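
# Example usage (a sketch; the `datasets` entries normally come from the training
# config, and the keys below mirror the ones `load_meta_data` reads):
#
#     datasets = [{
#         "name": "ljspeech",
#         "path": "/data/LJSpeech-1.1",
#         "meta_file_train": "metadata.csv",
#         "meta_file_val": None,
#         "meta_file_attn_mask": None,
#     }]
#     meta_train, meta_eval = load_meta_data(datasets)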


########################
# DATASETS
########################
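
# Every formatter below takes `(root_path, meta_file)` and returns a list of
# `[text, wav_file, speaker_name]` rows; `get_preprocessor_by_name` resolves a
# dataset's `name` to the function with the same lowercased name.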


def tweb(root_path, meta_file):
    """Normalize TWEB dataset.
    https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset
    """
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "tweb"
    with open(txt_file, "r") as ttf:
        for line in ttf:
            cols = line.split("\t")
            wav_file = os.path.join(root_path, cols[0] + ".wav")
            text = cols[1]
            items.append([text, wav_file, speaker_name])
    return items


def mozilla(root_path, meta_file):
    """Normalizes Mozilla meta data files to TTS format"""
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "mozilla"
    with open(txt_file, "r") as ttf:
        for line in ttf:
            cols = line.split("|")
            wav_file = cols[1].strip()
            text = cols[0].strip()
            wav_file = os.path.join(root_path, "wavs", wav_file)
            items.append([text, wav_file, speaker_name])
    return items


def mozilla_de(root_path, meta_file):
    """Normalizes Mozilla meta data files to TTS format"""
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "mozilla"
    with open(txt_file, "r", encoding="ISO 8859-1") as ttf:
        for line in ttf:
            cols = line.strip().split("|")
            wav_file = cols[0].strip()
            text = cols[1].strip()
            folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL"
            wav_file = os.path.join(root_path, folder_name, wav_file)
            items.append([text, wav_file, speaker_name])
    return items


def mailabs(root_path, meta_files=None):
    """Normalizes M-AI-Labs meta data files to TTS format"""
    speaker_regex = re.compile("by_book/(male|female)/(?P<speaker_name>[^/]+)/")
    if meta_files is None:
        csv_files = glob(root_path + "/**/metadata.csv", recursive=True)
    else:
        csv_files = meta_files
    items = []
    for csv_file in csv_files:
        txt_file = os.path.join(root_path, csv_file)
        folder = os.path.dirname(txt_file)
        # determine speaker based on folder structure...
        speaker_name_match = speaker_regex.search(txt_file)
        if speaker_name_match is None:
            continue
        speaker_name = speaker_name_match.group("speaker_name")
        print(" | > {}".format(csv_file))
        with open(txt_file, "r") as ttf:
            for line in ttf:
                cols = line.split("|")
                # `folder` is the book directory that also holds the `wavs` folder
                wav_file = os.path.join(folder, "wavs", cols[0] + ".wav")
                if os.path.isfile(wav_file):
                    text = cols[1].strip()
                    items.append([text, wav_file, speaker_name])
                else:
                    raise RuntimeError("> File %s does not exist!" % (wav_file))
    return items


def ljspeech(root_path, meta_file):
    """Normalizes the LJSpeech meta data file to TTS format
    https://keithito.com/LJ-Speech-Dataset/"""
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "ljspeech"
    with open(txt_file, "r", encoding="utf-8") as ttf:
        for line in ttf:
            cols = line.split("|")
            wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
            text = cols[1]
            items.append([text, wav_file, speaker_name])
    return items


def sam_accenture(root_path, meta_file):
    """Normalizes the sam-accenture meta data file to TTS format
    https://github.com/Sam-Accenture-Non-Binary-Voice/non-binary-voice-files"""
    xml_file = os.path.join(root_path, "voice_over_recordings", meta_file)
    xml_root = ET.parse(xml_file).getroot()
    items = []
    speaker_name = "sam_accenture"
    for item in xml_root.findall("./fileid"):
        text = item.text
        wav_file = os.path.join(root_path, "vo_voice_quality_transformation", item.get("id") + ".wav")
        if not os.path.exists(wav_file):
            print(f" [!] {wav_file} in metafile does not exist. Skipping...")
            continue
        items.append([text, wav_file, speaker_name])
    return items


def ruslan(root_path, meta_file):
    """Normalizes the RUSLAN meta data file to TTS format
    https://ruslan-corpus.github.io/"""
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "ljspeech"
    with open(txt_file, "r", encoding="utf-8") as ttf:
        for line in ttf:
            cols = line.split("|")
            wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav")
            text = cols[1]
            items.append([text, wav_file, speaker_name])
    return items


def css10(root_path, meta_file):
    """Normalizes the CSS10 dataset file to TTS format"""
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "ljspeech"
    with open(txt_file, "r") as ttf:
        for line in ttf:
            cols = line.split("|")
            wav_file = os.path.join(root_path, cols[0])
            text = cols[1]
            items.append([text, wav_file, speaker_name])
    return items


def nancy(root_path, meta_file):
    """Normalizes the Nancy meta data file to TTS format"""
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "nancy"
    with open(txt_file, "r") as ttf:
        for line in ttf:
            utt_id = line.split()[1]
            text = line[line.find('"') + 1 : line.rfind('"')]
            wav_file = os.path.join(root_path, "wavn", utt_id + ".wav")
            items.append([text, wav_file, speaker_name])
    return items


def common_voice(root_path, meta_file):
    """Normalize the common voice meta data file to TTS format."""
    txt_file = os.path.join(root_path, meta_file)
    items = []
    with open(txt_file, "r") as ttf:
        for line in ttf:
            if line.startswith("client_id"):
                continue
            cols = line.split("\t")
            text = cols[2]
            speaker_name = cols[0]
            wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav"))
            items.append([text, wav_file, "MCV_" + speaker_name])
    return items


def libri_tts(root_path, meta_files=None):
    """https://ai.google/tools/datasets/libri-tts/"""
    items = []
    if meta_files is None:
        meta_files = glob(f"{root_path}/**/*trans.tsv", recursive=True)
    for meta_file in meta_files:
        _meta_file = os.path.basename(meta_file).split(".")[0]
        speaker_name = _meta_file.split("_")[0]
        chapter_id = _meta_file.split("_")[1]
        _root_path = os.path.join(root_path, f"{speaker_name}/{chapter_id}")
        with open(meta_file, "r") as ttf:
            for line in ttf:
                cols = line.split("\t")
                wav_file = os.path.join(_root_path, cols[0] + ".wav")
                text = cols[1]
                items.append([text, wav_file, "LTTS_" + speaker_name])
    for item in items:
        assert os.path.exists(item[1]), f" [!] wav file does not exist - {item[1]}"
    return items


def custom_turkish(root_path, meta_file):
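    """Normalizes a custom Turkish single-speaker dataset (LJSpeech-like layout) to TTS format."""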
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "turkish-female"
    skipped_files = []
    with open(txt_file, "r", encoding="utf-8") as ttf:
        for line in ttf:
            cols = line.split("|")
            wav_file = os.path.join(root_path, "wavs", cols[0].strip() + ".wav")
            if not os.path.exists(wav_file):
                skipped_files.append(wav_file)
                continue
            text = cols[1].strip()
            items.append([text, wav_file, speaker_name])
    print(f" [!] {len(skipped_files)} files skipped. They don't exist...")
    return items


# ToDo: add the dataset link when the dataset is released publicly
def brspeech(root_path, meta_file):
    """BRSpeech 3.0 beta"""
    txt_file = os.path.join(root_path, meta_file)
    items = []
    with open(txt_file, "r") as ttf:
        for line in ttf:
            if line.startswith("wav_filename"):
                continue
            cols = line.split("|")
            wav_file = os.path.join(root_path, cols[0])
            text = cols[2]
            speaker_name = cols[3]
            items.append([text, wav_file, speaker_name])
    return items


def vctk(root_path, meta_files=None, wavs_path="wav48"):
    """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
    test_speakers = meta_files
    items = []
    meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
    for meta_file in meta_files:
        _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep)
        file_id = txt_file.split(".")[0]
        if isinstance(test_speakers, list) and speaker_id in test_speakers:
            # a list of speaker ids marks speakers held out for testing
            continue
        with open(meta_file) as file_text:
            text = file_text.readlines()[0]
        wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
        items.append([text, wav_file, "VCTK_" + speaker_id])

    return items


def vctk_slim(root_path, meta_files=None, wavs_path="wav48"):
    """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
    items = []
    txt_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
    for text_file in txt_files:
        _, speaker_id, txt_file = os.path.relpath(text_file, root_path).split(os.sep)
        file_id = txt_file.split(".")[0]
        if isinstance(meta_files, list) and speaker_id in meta_files:
            # a list of speaker ids marks speakers held out for testing
            continue
        wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
        items.append([None, wav_file, "VCTK_" + speaker_id])

    return items


# ======================================== VOX CELEB ===========================================
def voxceleb2(root_path, meta_file=None):
    """
    :param meta_file: used only for consistency with the load_meta_data API
    """
    return _voxcel_x(root_path, meta_file, voxcel_idx="2")


def voxceleb1(root_path, meta_file=None):
    """
    :param meta_file: used only for consistency with the load_meta_data API
    """
    return _voxcel_x(root_path, meta_file, voxcel_idx="1")


def _voxcel_x(root_path, meta_file, voxcel_idx):
    assert voxcel_idx in ["1", "2"]
    expected_count = 148_000 if voxcel_idx == "1" else 1_000_000
    voxceleb_path = Path(root_path)
    cache_to = voxceleb_path / f"metafile_voxceleb{voxcel_idx}.csv"
    cache_to.parent.mkdir(exist_ok=True)

    # with an explicit meta file, read it directly; otherwise crawl recursively
    # for wav files once and cache the result next to the dataset
    if meta_file is not None:
        with open(str(meta_file), "r") as f:
            return [x.strip().split("|") for x in f.readlines()]

    elif not cache_to.exists():
        cnt = 0
        meta_data = []
        wav_files = voxceleb_path.rglob("*.wav")  # rglob is already recursive
        for path in tqdm(
            wav_files,
            desc=f"Building VoxCeleb {voxcel_idx} Meta file ... this needs to be done only once.",
            total=expected_count,
        ):
            speaker_id = path.parent.parent.stem  # rglob yields Path objects; stem is already a str
            assert speaker_id.startswith("id")
            text = None  # VoxCeleb does not provide transcriptions, and they are not needed for training the SE
            meta_data.append(f"{text}|{path}|voxcel{voxcel_idx}_{speaker_id}\n")
            cnt += 1
        with open(str(cache_to), "w") as f:
            f.write("".join(meta_data))
        if cnt < expected_count:
            raise ValueError(f"Found too few instances for Voxceleb. Should be around {expected_count}, is: {cnt}")

    with open(str(cache_to), "r") as f:
        return [x.strip().split("|") for x in f.readlines()]


def baker(root_path: str, meta_file: str) -> List[List[str]]:
    """Normalizes the Baker meta data file to TTS format

    Args:
        root_path (str): path to the baker dataset
        meta_file (str): name of the meta file listing the wav file names and their transcripts
    Returns:
        List[List[str]]: list of [text, wav_path, speaker_name] entries, one per sentence
    """
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "baker"
    with open(txt_file, "r") as ttf:
        for line in ttf:
            wav_name, text = line.rstrip("\n").split("|")
            wav_path = os.path.join(root_path, "clips_22", wav_name)
            items.append([text, wav_path, speaker_name])
    return items


def kokoro(root_path, meta_file):
    """Japanese single-speaker dataset from https://github.com/kaiidams/Kokoro-Speech-Dataset"""
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "kokoro"
    with open(txt_file, "r") as ttf:
        for line in ttf:
            cols = line.split("|")
            wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
            text = cols[2].replace(" ", "")
            items.append([text, wav_file, speaker_name])
    return items
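

# A sketch of what a new formatter looks like (hypothetical layout: a
# pipe-separated meta file of `<file_id>|<text>` lines and audio under `wavs/`;
# adjust to the real dataset). Any function added here that returns
# `[text, wav_file, speaker_name]` rows is picked up by
# `get_preprocessor_by_name` via the dataset's `name` field.
def example_dataset(root_path, meta_file):
    """Hypothetical formatter illustrating the expected output format."""
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "example_dataset"
    with open(txt_file, "r", encoding="utf-8") as ttf:
        for line in ttf:
            cols = line.strip().split("|")
            wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
            items.append([cols[1], wav_file, speaker_name])
    return items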