import functools

import torch

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.languages import get_language_weighted_sampler

# Seed torch's global RNG so the sampler statistics checked below are
# reproducible and the test does not fail intermittently.
torch.manual_seed(0)

# Both configs point at the same LJSpeech test fixture and differ only in
# their ``language`` value.
_ljspeech_fixture_config = functools.partial(
    BaseDatasetConfig,
    name="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    path="tests/data/ljspeech",
)
dataset_config_en = _ljspeech_fixture_config(language="en")
dataset_config_pt = _ljspeech_fixture_config(language="pt-br")

# The EN config is listed twice, so EN samples outnumber PT-BR samples
# (roughly 2:1), giving a deliberately unbalanced dataset.
train_samples, eval_samples = load_tts_samples(
    [dataset_config_en, dataset_config_en, dataset_config_pt],
    eval_split=True,
)


def is_balanced(lang_1, lang_2, *, low=0.85, high=1.2):
    """Return True if two per-language sample counts are roughly equal.

    The counts are considered balanced when the ratio ``lang_1 / lang_2``
    lies strictly inside the open interval ``(low, high)``.

    Args:
        lang_1: Number of samples drawn for the first language.
        lang_2: Number of samples drawn for the second language.
        low: Exclusive lower bound on the ratio (default 0.85).
        high: Exclusive upper bound on the ratio (default 1.2).

    Returns:
        bool: Whether the ratio falls inside the tolerance window. A zero
        ``lang_2`` count is reported as unbalanced instead of raising
        ZeroDivisionError.
    """
    if lang_2 == 0:
        # Nothing was drawn for the second language at all, which is as
        # unbalanced as it gets (and the ratio is undefined).
        return False
    return low < lang_1 / lang_2 < high


def _draw_indices(sampler, num_passes=100):
    """Iterate ``sampler`` ``num_passes`` times and return all indices as one flat list."""
    # A flat comprehension is linear in the total number of indices; the
    # previous ``reduce(lambda a, b: a + b, ...)`` re-copied the accumulated
    # list on every step (quadratic). Iteration order, and therefore the
    # RNG draws the sampler makes, is unchanged.
    return [index for _ in range(num_passes) for index in sampler]


def _count_languages(indices):
    """Return ``(en, other)`` counts of the ``train_samples`` entries picked by ``indices``."""
    # Every non-"en" sample is counted as "other"; in this file the only other
    # language loaded is "pt-br".
    en_count = sum(1 for index in indices if train_samples[index]["language"] == "en")
    return en_count, len(indices) - en_count


# NOTE(review): these assertions execute at import time, so under pytest a
# failure surfaces as a collection error rather than as a failing test.
random_sampler = torch.utils.data.RandomSampler(train_samples)
ids = _draw_indices(random_sampler)
en, pt = _count_languages(ids)

assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"

weighted_sampler = get_language_weighted_sampler(train_samples)
ids = _draw_indices(weighted_sampler)
en, pt = _count_languages(ids)

assert is_balanced(en, pt), "Weighted sampler is supposed to be balanced"