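"""Compute attention masks from pre-trained Tacotron or Tacotron2 models.

For every item in the dataset, the model's alignment is saved as a ``.npy``
file next to its wav file, and a ``metadata_attn_mask.txt`` file listing
``wav_path|alignment_path`` pairs is written to ``--data_path``.

Sample run on the LJSpeech dataset: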

 >>>> CUDA_VISIBLE_DEVICES="0" python TTS/bin/compute_attention_masks.py \
     --model_path /home/erogol/Cluster/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_100000.pth.tar \
     --config_path /home/erogol/Cluster/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json --dataset ljspeech \
     --dataset_metafile /home/erogol/Data/LJSpeech-1.1/metadata.csv \
     --data_path /home/erogol/Data/LJSpeech-1.1/ \
     --batch_size 16 \
     --use_cuda true

"""


import argparse
import importlib
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import load_checkpoint
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Extract attention masks from trained Tacotron models.')
    parser.add_argument('--model_path',
                        type=str,
                        help='Path to Tacotron or Tacotron2 model file ')
    parser.add_argument(
        '--config_path',
        type=str,
        required=True,
        help='Path to config file for training.',
    )
    parser.add_argument('--dataset',
                        type=str,
                        default='',
                        help='Dataset from TTS.tts.dataset.preprocess.')

    parser.add_argument(
        '--dataset_metafile',
        type=str,
        default='',
        help='Dataset metafile inclusing file paths with transcripts.')
    parser.add_argument(
        '--data_path',
        type=str,
        default='',
        help='Defines the data path. It overwrites config.json.')
    parser.add_argument('--output_path',
                        type=str,
                        help='path for training outputs.',
                        default='')
    parser.add_argument('--output_folder',
                        type=str,
                        default='',
                        help='folder name for training outputs.')

    parser.add_argument('--use_cuda',
                        type=bool,
                        default=False,
                        help="enable/disable cuda.")

    parser.add_argument(
        '--batch_size',
        default=16,
        type=int,
        help='Batch size for the model. Use batch_size=1 if you have no CUDA.')
    args = parser.parse_args()

    C = load_config(args.config_path)
    ap = AudioProcessor(**C.audio)

    # if the vocabulary was passed, replace the default
    if 'characters' in C.keys():
        symbols, phonemes = make_symbols(**C.characters)

    # load the model
    num_chars = len(phonemes) if C.use_phonemes else len(symbols)
    # TODO: handle multi-speaker
    model = setup_model(num_chars, num_speakers=0, c=C)
    model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda)
    model.eval()

    # data loader
    preprocessor = importlib.import_module('TTS.tts.datasets.preprocess')
    preprocessor = getattr(preprocessor, args.dataset)
    meta_data = preprocessor(args.data_path, args.dataset_metafile)
    dataset = MyDataset(model.decoder.r,
                        C.text_cleaner,
                        compute_linear_spec=False,
                        ap=ap,
                        meta_data=meta_data,
                        tp=C.characters if 'characters' in C.keys() else None,
                        add_blank=C['add_blank'] if 'add_blank' in C.keys() else False,
                        use_phonemes=C.use_phonemes,
                        phoneme_cache_path=C.phoneme_cache_path,
                        phoneme_language=C.phoneme_language,
                        enable_eos_bos=C.enable_eos_bos_chars)

    dataset.sort_items()
    loader = DataLoader(dataset,
                        batch_size=args.batch_size,
                        num_workers=4,
                        collate_fn=dataset.collate_fn,
                        shuffle=False,
                        drop_last=False)

    # compute attentions
    file_paths = []
    with torch.no_grad():
        for data in tqdm(loader):
            # setup input data
            text_input = data[0]
            text_lengths = data[1]
            linear_input = data[3]
            mel_input = data[4]
            mel_lengths = data[5]
            stop_targets = data[6]
            item_idxs = data[7]

            # dispatch data to GPU
            if args.use_cuda:
                text_input = text_input.cuda()
                text_lengths = text_lengths.cuda()
                mel_input = mel_input.cuda()
                mel_lengths = mel_lengths.cuda()

            mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(
                text_input, text_lengths, mel_input)

            alignments = alignments.detach()
            for idx, alignment in enumerate(alignments):
                item_idx = item_idxs[idx]
                # interpolate if r > 1
                alignment = torch.nn.functional.interpolate(
                    alignment.transpose(0, 1).unsqueeze(0),
                    size=None,
                    scale_factor=model.decoder.r,
                    mode='nearest',
                    align_corners=None,
                    recompute_scale_factor=None).squeeze(0).transpose(0, 1)

                # remove paddings
                alignment = alignment[:mel_lengths[idx], :text_lengths[idx]].cpu().numpy()

                # set file paths
                wav_file_name = os.path.basename(item_idx)
                align_file_name = os.path.splitext(wav_file_name)[0] + '.npy'
                file_path = item_idx.replace(wav_file_name, align_file_name)
                # save output
                file_paths.append([item_idx, file_path])
                np.save(file_path, alignment)

        # ourpur metafile
        metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")

        with open(metafile, "w") as f:
            for p in file_paths:
                f.write(f"{p[0]}|{p[1]}\n")
        print(f" >> Metafile created: {metafile}")