import os
import shutil
import unittest

import numpy as np
import torch
from torch.utils.data import DataLoader

from tests import get_tests_data_path, get_tests_output_path
from TTS.tts.configs.shared_configs import BaseDatasetConfig, BaseTTSConfig
from TTS.tts.datasets import TTSDataset, load_tts_samples
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor

# pylint: disable=unused-variable

OUTPATH = os.path.join(get_tests_output_path(), "loader_tests/")
os.makedirs(OUTPATH, exist_ok=True)

# create a dummy config for testing data loaders.
c = BaseTTSConfig(text_cleaner="english_cleaners", num_loader_workers=0, batch_size=2, use_noise_augment=False)
c.r = 5
c.data_path = os.path.join(get_tests_data_path(), "ljspeech/")
ok_ljspeech = os.path.exists(c.data_path)

dataset_config = BaseDatasetConfig(
    formatter="ljspeech_test",  # ljspeech_test to multi-speaker
    meta_file_train="metadata.csv",
    meta_file_val=None,
    path=c.data_path,
    language="en",
)

DATA_EXIST = True
if not os.path.exists(c.data_path):
    DATA_EXIST = False

print(" > Dynamic data loader test: {}".format(DATA_EXIST))


class TestTTSDataset(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_loader_iter = 4
        self.ap = AudioProcessor(**c.audio)

    def _create_dataloader(self, batch_size, r, bgs, start_by_longest=False):
        # load dataset
        meta_data_train, meta_data_eval = load_tts_samples(dataset_config, eval_split=True, eval_split_size=0.2)
        items = meta_data_train + meta_data_eval

        tokenizer, _ = TTSTokenizer.init_from_config(c)
        dataset = TTSDataset(
            outputs_per_step=r,
            compute_linear_spec=True,
            return_wav=True,
            tokenizer=tokenizer,
            ap=self.ap,
            samples=items,
            batch_group_size=bgs,
            min_text_len=c.min_text_len,
            max_text_len=c.max_text_len,
            min_audio_len=c.min_audio_len,
            max_audio_len=c.max_audio_len,
            start_by_longest=start_by_longest,
        )
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=dataset.collate_fn,
            drop_last=True,
            num_workers=c.num_loader_workers,
        )
        return dataloader, dataset

    def test_loader(self):
        if ok_ljspeech:
            dataloader, dataset = self._create_dataloader(1, 1, 0)

            for i, data in enumerate(dataloader):
                if i == self.max_loader_iter:
                    break
                text_input = data["token_id"]
                _ = data["token_id_lengths"]
                speaker_name = data["speaker_names"]
                linear_input = data["linear"]
                mel_input = data["mel"]
                mel_lengths = data["mel_lengths"]
                _ = data["stop_targets"]
                _ = data["item_idxs"]
                wavs = data["waveform"]

                neg_values = text_input[text_input < 0]
                check_count = len(neg_values)

                # check basic conditions
                self.assertEqual(check_count, 0)
                self.assertEqual(linear_input.shape[0], mel_input.shape[0], c.batch_size)
                self.assertEqual(linear_input.shape[2], self.ap.fft_size // 2 + 1)
                self.assertEqual(mel_input.shape[2], c.audio["num_mels"])
                self.assertEqual(wavs.shape[1], mel_input.shape[1] * c.audio.hop_length)
                self.assertIsInstance(speaker_name[0], str)

                # make sure that the computed mels and the waveform match and correctly computed
                mel_new = self.ap.melspectrogram(wavs[0].squeeze().numpy())
                # remove padding in mel-spectrogram
                mel_dataloader = mel_input[0].T.numpy()[:, : mel_lengths[0]]
                # guarantee that both mel-spectrograms have the same size and that we will remove waveform padding
                mel_new = mel_new[:, : mel_lengths[0]]
                ignore_seg = -(1 + c.audio.win_length // c.audio.hop_length)
                mel_diff = (mel_new[:, : mel_input.shape[1]] - mel_input[0].T.numpy())[:, 0:ignore_seg]
                self.assertLess(abs(mel_diff.sum()), 1e-5)

                # check normalization ranges
                if self.ap.symmetric_norm:
                    self.assertLessEqual(mel_input.max(), self.ap.max_norm)
                    self.assertGreaterEqual(
                        mel_input.min(), -self.ap.max_norm  # pylint: disable=invalid-unary-operand-type
                    )
                    self.assertLess(mel_input.min(), 0)
                else:
                    self.assertLessEqual(mel_input.max(), self.ap.max_norm)
                    self.assertGreaterEqual(mel_input.min(), 0)

    def test_batch_group_shuffle(self):
        if ok_ljspeech:
            dataloader, dataset = self._create_dataloader(2, c.r, 16)
            last_length = 0
            frames = dataset.samples
            for i, data in enumerate(dataloader):
                if i == self.max_loader_iter:
                    break
                mel_lengths = data["mel_lengths"]
                avg_length = mel_lengths.numpy().mean()
            dataloader.dataset.preprocess_samples()
            is_items_reordered = False
            for idx, item in enumerate(dataloader.dataset.samples):
                if item != frames[idx]:
                    is_items_reordered = True
                    break
            self.assertGreaterEqual(avg_length, last_length)
            self.assertTrue(is_items_reordered)

    def test_start_by_longest(self):
        """Test start_by_longest option.

        Ther first item of the fist batch must be longer than all the other items.
        """
        if ok_ljspeech:
            dataloader, _ = self._create_dataloader(2, c.r, 0, True)
            dataloader.dataset.preprocess_samples()
            for i, data in enumerate(dataloader):
                if i == self.max_loader_iter:
                    break
                mel_lengths = data["mel_lengths"]
                if i == 0:
                    max_len = mel_lengths[0]
                print(mel_lengths)
                self.assertTrue(all(max_len >= mel_lengths))

    def test_padding_and_spectrograms(self):
        def check_conditions(idx, linear_input, mel_input, stop_target, mel_lengths):
            self.assertNotEqual(linear_input[idx, -1].sum(), 0)  # check padding
            self.assertNotEqual(linear_input[idx, -2].sum(), 0)
            self.assertNotEqual(mel_input[idx, -1].sum(), 0)
            self.assertNotEqual(mel_input[idx, -2].sum(), 0)
            self.assertEqual(stop_target[idx, -1], 1)
            self.assertEqual(stop_target[idx, -2], 0)
            self.assertEqual(stop_target[idx].sum(), 1)
            self.assertEqual(len(mel_lengths.shape), 1)
            self.assertEqual(mel_lengths[idx], linear_input[idx].shape[0])
            self.assertEqual(mel_lengths[idx], mel_input[idx].shape[0])

        if ok_ljspeech:
            dataloader, _ = self._create_dataloader(1, 1, 0)

            for i, data in enumerate(dataloader):
                if i == self.max_loader_iter:
                    break
                linear_input = data["linear"]
                mel_input = data["mel"]
                mel_lengths = data["mel_lengths"]
                stop_target = data["stop_targets"]
                item_idx = data["item_idxs"]

                # check mel_spec consistency
                wav = np.asarray(self.ap.load_wav(item_idx[0]), dtype=np.float32)
                mel = self.ap.melspectrogram(wav).astype("float32")
                mel = torch.FloatTensor(mel).contiguous()
                mel_dl = mel_input[0]
                # NOTE: Below needs to check == 0 but due to an unknown reason
                # there is a slight difference between two matrices.
                # TODO: Check this assert cond more in detail.
                self.assertLess(abs(mel.T - mel_dl).max(), 1e-5)

                # check mel-spec correctness
                mel_spec = mel_input[0].cpu().numpy()
                wav = self.ap.inv_melspectrogram(mel_spec.T)
                self.ap.save_wav(wav, OUTPATH + "/mel_inv_dataloader.wav")
                shutil.copy(item_idx[0], OUTPATH + "/mel_target_dataloader.wav")

                # check linear-spec
                linear_spec = linear_input[0].cpu().numpy()
                wav = self.ap.inv_spectrogram(linear_spec.T)
                self.ap.save_wav(wav, OUTPATH + "/linear_inv_dataloader.wav")
                shutil.copy(item_idx[0], OUTPATH + "/linear_target_dataloader.wav")

                # check the outputs
                check_conditions(0, linear_input, mel_input, stop_target, mel_lengths)

            # Test for batch size 2
            dataloader, _ = self._create_dataloader(2, 1, 0)

            for i, data in enumerate(dataloader):
                if i == self.max_loader_iter:
                    break
                linear_input = data["linear"]
                mel_input = data["mel"]
                mel_lengths = data["mel_lengths"]
                stop_target = data["stop_targets"]
                item_idx = data["item_idxs"]

                # set id to the longest sequence in the batch
                if mel_lengths[0] > mel_lengths[1]:
                    idx = 0
                else:
                    idx = 1

                # check the longer item in the batch
                check_conditions(idx, linear_input, mel_input, stop_target, mel_lengths)

                # check the other item in the batch
                self.assertEqual(linear_input[1 - idx, -1].sum(), 0)
                self.assertEqual(mel_input[1 - idx, -1].sum(), 0)
                self.assertEqual(stop_target[1, mel_lengths[1] - 1], 1)
                self.assertEqual(stop_target[1, mel_lengths[1] :].sum(), stop_target.shape[1] - mel_lengths[1])
                self.assertEqual(len(mel_lengths.shape), 1)

                # check batch zero-frame conditions (zero-frame disabled)
                # assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
                # assert (mel_input * stop_target.unsqueeze(2)).sum() == 0