import functools
import random
import unittest

import torch

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.data import get_length_balancer_weights
from TTS.tts.utils.languages import get_language_balancer_weights
from TTS.tts.utils.speakers import get_speaker_balancer_weights
from TTS.utils.samplers import BucketBatchSampler, PerfectBatchSampler

# Fix the global random state of torch and of Python's ``random`` module so that
# sampling is reproducible and the statistical tests below do not fail randomly.
torch.manual_seed(0)
random.seed(0)

dataset_config_en = BaseDatasetConfig(
    formatter="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    path="tests/data/ljspeech",
    language="en",
)

# Same recordings as ``dataset_config_en``, labelled with a different language.
dataset_config_pt = BaseDatasetConfig(
    formatter="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    path="tests/data/ljspeech",
    language="pt-br",
)

# Adding the EN samples twice to create a language unbalanced dataset
train_samples, eval_samples = load_tts_samples(
    [dataset_config_en, dataset_config_en, dataset_config_pt], eval_split=True
)

# Generate a speaker-unbalanced dataset: the first 5 training samples belong to
# "ljspeech-0", every other sample belongs to "ljspeech-1".
for i, sample in enumerate(train_samples):
    sample["speaker_name"] = "ljspeech-0" if i < 5 else "ljspeech-1"


def is_balanced(lang_1, lang_2, *, low=0.85, high=1.2):
    """Return True when the ratio ``lang_1 / lang_2`` lies strictly inside ``(low, high)``.

    Used by the sampler tests to decide whether two draw counts are roughly equal.

    Args:
        lang_1: Count (or weight) for the first group.
        lang_2: Count (or weight) for the second group.
        low: Exclusive lower bound on the ratio (default 0.85).
        high: Exclusive upper bound on the ratio (default 1.2).

    Returns:
        bool: True if the ratio is inside the bounds. A zero ``lang_2`` means the
        second group was never drawn, which is reported as unbalanced rather than
        raising ``ZeroDivisionError``.
    """
    if lang_2 == 0:
        return False
    return low < lang_1 / lang_2 < high


class TestSamplers(unittest.TestCase):
    """Statistical checks for the dataset samplers and the balancing-weight helpers.

    All tests use the module-level ``train_samples``, which is built to be unbalanced
    in language (the EN config is loaded twice, PT once) and in speaker (the first 5
    samples are ``ljspeech-0``, the rest ``ljspeech-1``).
    """

    @staticmethod
    def _draw(sampler, rounds=100):
        """Return the items of ``rounds`` consecutive full passes over ``sampler`` as one list."""
        # ``list.__iadd__`` extends the accumulator in place and returns it, so this is
        # linear in the number of items (``a + b`` would re-copy the growing list each step).
        return functools.reduce(list.__iadd__, (list(sampler) for _ in range(rounds)), [])

    @staticmethod
    def _count_split(indices, predicate):
        """Split ``indices`` by ``predicate`` evaluated on the referenced ``train_samples`` item.

        Returns:
            tuple[int, int]: (number of matching indices, number of non-matching indices).
        """
        matches = misses = 0
        for index in indices:
            if predicate(train_samples[index]):
                matches += 1
            else:
                misses += 1
        return matches, misses

    def _assert_perfectly_balanced_batches(self, shuffle, drop_last):
        """Check that every PerfectBatchSampler batch has equally many samples per speaker."""
        classes = {item["speaker_name"] for item in train_samples}
        sampler = PerfectBatchSampler(
            train_samples,
            classes,
            batch_size=2 * 3,  # total batch size
            num_classes_in_batch=2,
            label_key="speaker_name",
            shuffle=shuffle,
            drop_last=drop_last,
        )
        for batch in self._draw(sampler):
            spk1, spk2 = self._count_split(batch, lambda s: s["speaker_name"] == "ljspeech-0")
            assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced"

    def test_language_random_sampler(self):
        random_sampler = torch.utils.data.RandomSampler(train_samples)
        en, pt = self._count_split(self._draw(random_sampler), lambda s: s["language"] == "en")
        assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"

    def test_language_weighted_random_sampler(self):
        weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            get_language_balancer_weights(train_samples), len(train_samples)
        )
        en, pt = self._count_split(self._draw(weighted_sampler), lambda s: s["language"] == "en")
        assert is_balanced(en, pt), "Language Weighted sampler is supposed to be balanced"

    def test_speaker_weighted_random_sampler(self):
        weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            get_speaker_balancer_weights(train_samples), len(train_samples)
        )
        spk1, spk2 = self._count_split(
            self._draw(weighted_sampler), lambda s: s["speaker_name"] == "ljspeech-0"
        )
        assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced"

    def test_perfect_sampler(self):
        self._assert_perfectly_balanced_batches(shuffle=False, drop_last=True)

    def test_perfect_sampler_shuffle(self):
        self._assert_perfectly_balanced_batches(shuffle=True, drop_last=False)

    def test_length_weighted_random_sampler(self):
        # A dedicated, seeded RNG keeps the generated durations reproducible and
        # independent of whatever other code did to the global ``random`` state.
        rng = random.Random(0)
        for _ in range(1000):
            # generate a length-unbalanced dataset with random min/max audio length
            min_audio = rng.randrange(1, 22050)
            max_audio = rng.randrange(44100, 220500)
            for idx, item in enumerate(train_samples):
                # increase the diversity of durations
                random_increase = rng.randrange(100, 1000)
                if idx < 5:
                    item["audio_length"] = min_audio + random_increase
                else:
                    item["audio_length"] = max_audio + random_increase

            weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(
                get_length_balancer_weights(train_samples, num_buckets=2), len(train_samples)
            )
            # Short items are < 22050 + 1000 < 44100 <= max_audio, so this splits short vs long.
            len1, len2 = self._count_split(
                self._draw(weighted_sampler), lambda s, limit=max_audio: s["audio_length"] < limit
            )
            assert is_balanced(len1, len2), "Length Weighted sampler is supposed to be balanced"

    def test_bucket_batch_sampler(self):
        bucket_size_multiplier = 2
        batch_size = 7
        sampler = BucketBatchSampler(
            range(len(train_samples)),
            data=train_samples,
            batch_size=batch_size,
            drop_last=True,
            sort_key=lambda x: len(x["text"]),
            bucket_size_multiplier=bucket_size_multiplier,
        )

        # Every emitted batch must be ordered by text length. Batches coming from the same
        # bucket are not assumed to be yielded in order, so only within-batch order is
        # checked. The previous check skipped every other batch.
        for batch in sampler:
            text_lens = [len(train_samples[idx]["text"]) for idx in batch]
            self.assertEqual(text_lens, sorted(text_lens))

        # check sampler length
        self.assertEqual(len(sampler), len(train_samples) // batch_size)