#!/usr/bin/env python3
"""Extract Mel spectrograms with teacher forcing."""

import argparse
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from TTS.config import load_config
from TTS.tts.datasets import TTSDataset, load_tts_samples
from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters
from trainer.generic_utils import to_cuda

use_cuda = torch.cuda.is_available()

def set_filename(wav_path, out_path):
    wav_file = os.path.basename(wav_path)
    file_name = wav_file.split(".")[0]
    os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "wav"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "wav_gt"), exist_ok=True)
    wavq_path = os.path.join(out_path, "quant", file_name)
    wav_gt_path = os.path.join(out_path, "wav_gt", file_name + ".wav")
    wav_path = os.path.join(out_path, "wav", file_name + ".wav")
    return file_name, wavq_path, wav_gt_path, wav_path


def extract_audios(
    data_loader, model, ap, output_path, quantized_wav=False, save_gt_audio=False, use_cuda=True
):
    model.eval()
    export_metadata = []
    for _, batch in tqdm(enumerate(data_loader), total=len(data_loader)):

        batch = model.format_batch(batch)
        batch = model.format_batch_on_device(batch)
        
        if use_cuda:
            for k, v in batch.items():
                    batch[k] = to_cuda(v)

        tokens = batch["tokens"]
        token_lenghts = batch["token_lens"]
        spec = batch["spec"]
        spec_lens = batch["spec_lens"]
        d_vectors = batch["d_vectors"]
        speaker_ids = batch["speaker_ids"]
        language_ids = batch["language_ids"]
        item_idx = batch["audio_files_path"]
        wav_lengths = batch["waveform_lens"]

        outputs = model.inference_with_MAS(
            tokens,
            spec,
            spec_lens,
            aux_input={"x_lengths": token_lenghts, "d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
        )

        model_output = outputs["model_outputs"]
        model_output = model_output.detach().cpu().numpy()

        for idx in range(tokens.shape[0]):
            wav_file_path = item_idx[idx]
            wav_gt = ap.load_wav(wav_file_path)
            
            _, wavq_path, wav_gt_path, wav_path = set_filename(wav_file_path, output_path)

            # quantize and save wav
            if quantized_wav:
                wavq = ap.quantize(wav_gt)
                np.save(wavq_path, wavq)

            # save TTS mel
            wav = model_output[idx][0]
            wav_length = wav_lengths[idx]
            wav = wav[:wav_length]
            ap.save_wav(wav, wav_path)

            if save_gt_audio:
                ap.save_wav(wav_gt, wav_gt_path)


def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global meta_data, speaker_manager

    # Audio processor
    ap = AudioProcessor(**c.audio)

    # load data instances
    meta_data_train, meta_data_eval = load_tts_samples(
        c.datasets, eval_split=args.eval, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size
    )

    # use eval and training partitions
    meta_data = meta_data_train + meta_data_eval

    # setup model
    model = setup_model(c, meta_data)

    # restore model
    model.load_checkpoint(c, args.checkpoint_path, eval=True)

    if use_cuda:
        model.cuda()

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    own_loader = model.get_data_loader(config=model.config,
        assets={},
        is_eval=False,
        samples=meta_data,
        verbose=True,
        num_gpus=1,
    )

    extract_audios(
        own_loader,
        model,
        ap,
        args.output_path,
        quantized_wav=args.quantized,
        save_gt_audio=args.save_gt_audio,
        use_cuda=use_cuda,
    )


if __name__ == "__main__":
    # python3 TTS/bin/extract_tts_audio.py --config_path /raid/edresson/dev/Checkpoints/YourTTS/new_vctk_trimmed_silence/upsampling/YourTTS_22khz--\>44khz_vocoder_approach_frozen/YourTTS_22khz--\>44khz_vocoder_approach_frozen-April-02-2022_08+23PM-a5f5ebae/config.json --checkpoint_path /raid/edresson/dev/Checkpoints/YourTTS/new_vctk_trimmed_silence/upsampling/YourTTS_22khz--\>44khz_vocoder_approach_frozen/YourTTS_22khz--\>44khz_vocoder_approach_frozen-April-02-2022_08+23PM-a5f5ebae/checkpoint_1600000.pth --output_path ../Test_extract_audio_script/
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True)
    parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True)
    parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True)
    parser.add_argument("--save_gt_audio", default=False, action="store_true", help="Save audio files")
    parser.add_argument("--quantized", action="store_true", help="Save quantized audio files")
    parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
    args = parser.parse_args()

    c = load_config(args.config_path)
    c.audio.trim_silence = False
    c.batch_size = 4
    main(args)