import os
import numpy as np
import collections
import torch
import random
from torch.utils.data import Dataset

from TTS.utils.text import text_to_sequence, phoneme_to_sequence, pad_with_eos_bos
from TTS.utils.data import prepare_data, prepare_tensor, prepare_stop_target


class MyDataset(Dataset):
    """Speaker-encoder dataset built from ``(text, wav_file, speaker_name)`` items.

    ``__getitem__`` ignores its index and returns a randomly drawn speaker
    name. ``collate_fn`` turns a batch of speaker names into
    ``num_utter_per_speaker`` random fixed-length mel-spectrogram segments per
    speaker. ``__len__`` returns a very large constant, so iteration length is
    effectively controlled by the training loop.
    """

    def __init__(self,
                 ap,
                 meta_data,
                 voice_len=1.6,
                 num_speakers_in_batch=64,
                 num_utter_per_speaker=10,
                 skip_speakers=False,
                 verbose=False):
        """
        Args:
            ap (TTS.utils.AudioProcessor): audio processor object. Must expose
                ``sample_rate``, ``load_wav(filename)`` and ``melspectrogram(wav)``.
            meta_data (list): dataset instances as ``(text, wav_file, speaker_name)``
                tuples.
            voice_len (float): voice segment length in seconds; converted to a
                sample count ``seq_len = int(voice_len * ap.sample_rate)``.
            num_speakers_in_batch (int): number of speakers per batch. Not used
                inside this class; kept as an attribute for callers.
            num_utter_per_speaker (int): utterances sampled per speaker.
            skip_speakers (bool): drop speakers with fewer than
                ``num_utter_per_speaker`` utterances.
            verbose (bool): print diagnostic information.
        """
        self.items = meta_data
        self.sample_rate = ap.sample_rate
        self.voice_len = voice_len
        self.seq_len = int(voice_len * self.sample_rate)
        self.num_speakers_in_batch = num_speakers_in_batch
        self.num_utter_per_speaker = num_utter_per_speaker
        self.skip_speakers = skip_speakers
        self.ap = ap
        self.verbose = verbose
        self.__parse_items()
        if self.verbose:
            print("\n > DataLoader initialization")
            print(f" | > Number of instances : {len(self.items)}")
            print(f" | > Sequence length: {self.seq_len}")
            print(f" | > Num speakers: {len(self.speakers)}")

    def load_wav(self, filename):
        """Load a waveform through the audio processor."""
        audio = self.ap.load_wav(filename)
        return audio

    def load_data(self, idx):
        """Compute the full-length mel spectrogram of ``self.items[idx]``.

        Returns:
            dict: ``mel`` (float32 array from ``ap.melspectrogram``),
            ``item_idx`` (the item's wav file entry) and ``speaker_name``.
        """
        text, wav_file, speaker_name = self.items[idx]
        wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
        mel = self.ap.melspectrogram(wav).astype('float32')

        # ``text`` is a string, which has no ``.size``; ``len`` works for
        # strings and other sequences alike.
        assert len(text) > 0, wav_file
        assert wav.size > 0, wav_file

        sample = {
            'mel': mel,
            'item_idx': wav_file,
            'speaker_name': speaker_name
        }
        return sample

    def __parse_items(self):
        """Group utterances by speaker into ``self.speaker_to_utters``.

        Also fills ``self.speakers``. When ``skip_speakers`` is set, speakers
        with fewer than ``num_utter_per_speaker`` utterances are left out of
        both.
        """
        # One pass over the items (instead of rescanning every item for every
        # speaker); a dict keeps a deterministic, first-seen speaker order.
        grouped = {}
        for item in self.items:
            grouped.setdefault(item[2], []).append(item[1])

        self.speaker_to_utters = {}
        self.speakers = []
        for speaker, speaker_utters in grouped.items():
            if len(speaker_utters) < self.num_utter_per_speaker and self.skip_speakers:
                print(f" [!] Skipped speaker {speaker}. Not enough utterances {self.num_utter_per_speaker} vs {len(speaker_utters)}.")
            else:
                self.speakers.append(speaker)
                self.speaker_to_utters[speaker] = speaker_utters

    def __len__(self):
        return int(1e+10)

    def __sample_speaker(self):
        """Pick a random speaker and ``num_utter_per_speaker`` of its utterances.

        Utterances are drawn with replacement only when the speaker has fewer
        than ``num_utter_per_speaker`` of them.

        Raises:
            RuntimeError: if every speaker has been exhausted.
        """
        if not self.speakers:
            raise RuntimeError(" [!] No speakers left with usable utterances.")
        speaker = random.choice(self.speakers)
        utters = self.speaker_to_utters[speaker]
        if self.num_utter_per_speaker > len(utters):
            sampled = random.choices(utters, k=self.num_utter_per_speaker)
        else:
            sampled = random.sample(utters, self.num_utter_per_speaker)
        return speaker, sampled

    def __sample_speaker_utterances(self, speaker):
        """Crop ``num_utter_per_speaker`` random mel segments for ``speaker``.

        Utterances shorter than ``seq_len`` samples are removed from
        ``self.speaker_to_utters`` permanently. A speaker left with no
        utterances is dropped from ``self.speakers`` and replaced by a freshly
        sampled one; the returned labels name the speaker actually used.

        Returns:
            tuple: ``(feats, labels)`` with one ``torch.FloatTensor`` and one
            speaker name per sampled segment.
        """
        feats = []
        labels = []
        for _ in range(self.num_utter_per_speaker):
            # Retry until an utterance long enough to crop from is found.
            while True:
                utters = self.speaker_to_utters[speaker]
                if not utters:
                    # The speaker may already have been dropped if it occurred
                    # earlier in the same batch; removing twice would raise.
                    if speaker in self.speakers:
                        self.speakers.remove(speaker)
                    speaker, _ = self.__sample_speaker()
                    continue
                utter = random.choice(utters)
                wav = self.load_wav(utter)
                # A wav of exactly seq_len samples is usable (offset 0).
                if wav.shape[0] >= self.seq_len:
                    break
                # Too short: discard so it is never picked again.
                utters.remove(utter)

            offset = random.randint(0, wav.shape[0] - self.seq_len)
            mel = self.ap.melspectrogram(wav[offset:offset + self.seq_len])
            feats.append(torch.FloatTensor(mel))
            labels.append(speaker)
        return feats, labels

    def __getitem__(self, idx):
        """Return a randomly sampled speaker name; ``idx`` is ignored."""
        speaker, _ = self.__sample_speaker()
        return speaker

    def collate_fn(self, batch):
        """Load audio segments for a batch of speaker names.

        Returns:
            tuple: ``(feats, labels)``. ``feats`` is all segments stacked along
            a new first dimension with dims 1 and 2 swapped — presumably
            ``(num_speakers * num_utter_per_speaker, frames, mel_channels)`` if
            ``melspectrogram`` returns ``(mel_channels, frames)``; confirm.
            ``labels`` holds one list of speaker names per batch entry.
        """
        labels = []
        feats = []
        for speaker in batch:
            feats_, labels_ = self.__sample_speaker_utterances(speaker)
            labels.append(labels_)
            feats.extend(feats_)
        feats = torch.stack(feats)
        return feats.transpose(1, 2), labels