import argparse
import os
import sys
import traceback
import time
import glob
import random

import torch
from torch.utils.data import DataLoader

# from torch.utils.data.distributed import DistributedSampler

from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.radam import RAdam
from TTS.utils.io import copy_config_file, load_config
from TTS.utils.training import setup_torch_training_env
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.generic_utils import (
    KeepAverage,
    count_parameters,
    create_experiment_folder,
    get_git_branch,
    remove_experiment_folder,
    set_init_dict,
)
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.datasets.preprocess import (
    load_wav_data,
    load_wav_feat_data
)
from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
from TTS.vocoder.utils.generic_utils import setup_wavernn
from TTS.vocoder.utils.io import save_best_model, save_checkpoint


# Configure torch for training and discover the compute setup. The helper is
# defined outside this file; NOTE(review): the two positional flags are
# presumably cudnn enable/benchmark switches -- confirm against its signature.
# ``use_cuda`` decides whether batches (format_data) and the model/criterion
# (main) are moved to the GPU.
use_cuda, num_gpus = setup_torch_training_env(True, True)


def setup_loader(ap, is_val=False, verbose=False):
    """Build the DataLoader for the training or the validation split.

    Reads the module-level config ``c`` and the ``train_data`` / ``eval_data``
    item lists that ``main`` populates.

    Args:
        ap: audio processor; its ``hop_length`` is forwarded to the dataset.
        is_val: build the loader over ``eval_data`` instead of ``train_data``.
        verbose: forwarded to ``WaveRNNDataset``.

    Returns:
        A ``DataLoader``, or ``None`` when ``is_val`` is set and
        ``c.run_eval`` is disabled.
    """
    if is_val and not c.run_eval:
        return None

    dataset = WaveRNNDataset(ap=ap,
                             items=eval_data if is_val else train_data,
                             seq_len=c.seq_len,
                             hop_len=ap.hop_length,
                             pad=c.padding,
                             mode=c.mode,
                             mulaw=c.mulaw,
                             is_training=not is_val,
                             verbose=verbose,
                             )
    # NOTE: no DistributedSampler is used, so every process sees the whole split.
    return DataLoader(dataset,
                      # only training needs shuffling; a fixed order keeps the
                      # validation batches reproducible from epoch to epoch
                      shuffle=not is_val,
                      collate_fn=dataset.collate,
                      batch_size=c.batch_size,
                      num_workers=c.num_val_loader_workers
                      if is_val
                      else c.num_loader_workers,
                      # pinned host memory only speeds up host->GPU copies
                      pin_memory=use_cuda,
                      )


def format_data(data):
    """Split a collated batch into model inputs and targets.

    Takes the first three entries of ``data`` -- the input signal, the mel
    conditioning and the target -- and, when CUDA is in use, copies each of
    them to the GPU asynchronously.

    Returns:
        tuple: ``(x_input, mels, y_coarse)``.
    """
    batch = tuple(data[position] for position in range(3))
    if not use_cuda:
        return batch
    # non_blocking lets the copy overlap with host work when memory is pinned
    return tuple(tensor.cuda(non_blocking=True) for tensor in batch)


def _wavernn_train_loss(model, criterion, x_input, mels, y_coarse):
    """Run ``model`` on ``x_input``/``mels`` and score the output against ``y_coarse``.

    With an integer ``model.mode`` the prediction is transposed on axes 1/2
    and given a trailing axis; otherwise the target is cast to float. The
    target always receives a trailing axis before ``criterion`` is applied.
    """
    y_hat = model(x_input, mels)
    if isinstance(model.mode, int):
        y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
    else:
        y_coarse = y_coarse.float()
    y_coarse = y_coarse.unsqueeze(-1)
    return criterion(y_hat, y_coarse)


def _log_train_sample(model, ap, global_step):
    """Vocode one random training utterance and log it to TensorBoard.

    The ground-truth mel is computed from the wav with ``ap``. The generated
    waveform is logged as audio, together with the ground-truth and
    predicted spectrograms.
    """
    rand_idx = random.randrange(0, len(train_data))
    item = train_data[rand_idx]
    # items are either plain wav paths or (wav_path, ...) tuples/lists
    wav_path = item[0] if isinstance(item, (tuple, list)) else item
    wav = ap.load_wav(wav_path)
    ground_mel = ap.melspectrogram(wav)
    # sampling needs no gradients; avoid building an autograd graph
    with torch.no_grad():
        sample_wav = model.generate(ground_mel,
                                    c.batched,
                                    c.target_samples,
                                    c.overlap_samples,
                                    use_cuda
                                    )
    # NOTE(review): generate() may switch the model to eval mode; make sure
    # the following training steps run in training mode either way.
    model.train()
    predict_mel = ap.melspectrogram(sample_wav)

    figures = {"train/ground_truth": plot_spectrogram(ground_mel.T),
               "train/prediction": plot_spectrogram(predict_mel.T)
               }
    tb_logger.tb_train_figures(global_step, figures)
    tb_logger.tb_train_audios(
        global_step, {"train/audio": sample_wav}, c.audio["sample_rate"]
    )


def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch):
    """Run one training epoch.

    Each batch gets a forward/backward pass, optional gradient clipping
    (``c.grad_clip``), an optimizer step and a per-step scheduler step. The
    AMP path is used when ``c.mixed_precision`` is set. Console and
    TensorBoard logging happen every ``c.print_step`` / 10 steps. Every
    ``c.save_step`` steps a checkpoint is written (if ``c.checkpoint``) and a
    sample is vocoded and logged.

    Args:
        model: the WaveRNN model.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss callable ``criterion(y_hat, y)``.
        scheduler: LR scheduler stepped every iteration, or ``None``.
        scaler: AMP grad scaler; only used when ``c.mixed_precision``.
        ap: audio processor.
        global_step: step counter before this epoch.
        epoch: current epoch index (the dataset is verbose on epoch 0).

    Returns:
        tuple: ``(avg_values, global_step)`` -- the epoch's running averages
        and the updated step counter.

    Raises:
        RuntimeError: if the full-precision loss is NaN or infinite.
    """
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    # The loader is not sharded across GPUs (no DistributedSampler), so its
    # own length is the number of iterations in this epoch.
    batch_n_iter = len(data_loader)
    end_time = time.time()
    c_logger.print_train_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()
        x_input, mels, y_coarse = format_data(data)
        loader_time = time.time() - end_time
        global_step += 1

        optimizer.zero_grad()

        if c.mixed_precision:
            with torch.cuda.amp.autocast():
                loss = _wavernn_train_loss(model, criterion, x_input, mels, y_coarse)
            scaler.scale(loss).backward()
            # unscale first so clipping sees the true gradient norm
            scaler.unscale_(optimizer)
            if c.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss = _wavernn_train_loss(model, criterion, x_input, mels, y_coarse)
            # loss.item() is a float and can never be None; check for NaN/inf instead
            if not torch.isfinite(loss):
                raise RuntimeError(f" [!] Non-finite loss ({loss.item()}). Exiting ...")
            loss.backward()
            if c.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()

        cur_lr = optimizer.param_groups[0]["lr"]

        step_time = time.time() - start_time
        epoch_time += step_time

        loss_dict = {"model_loss": loss.item()}
        update_train_values = {"avg_" + key: value for key, value in loss_dict.items()}
        update_train_values["avg_loader_time"] = loader_time
        update_train_values["avg_step_time"] = step_time
        keep_avg.update_values(update_train_values)

        # print training stats
        if global_step % c.print_step == 0:
            log_dict = {"step_time": [step_time, 2],
                        "loader_time": [loader_time, 4],
                        "current_lr": cur_lr,
                        }
            c_logger.print_train_step(batch_n_iter,
                                      num_iter,
                                      global_step,
                                      log_dict,
                                      loss_dict,
                                      keep_avg.avg_values,
                                      )

        # plot step stats
        if global_step % 10 == 0:
            iter_stats = {"lr": cur_lr, "step_time": step_time}
            iter_stats.update(loss_dict)
            tb_logger.tb_train_iter_stats(global_step, iter_stats)

        if global_step % c.save_step == 0:
            if c.checkpoint:
                save_checkpoint(model,
                                optimizer,
                                scheduler,
                                None,
                                None,
                                None,
                                global_step,
                                epoch,
                                OUT_PATH,
                                model_losses=loss_dict,
                                scaler=scaler.state_dict() if c.mixed_precision else None
                                )
            _log_train_sample(model, ap, global_step)
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Training Epoch Stats
    epoch_stats = {"epoch_time": epoch_time}
    epoch_stats.update(keep_avg.avg_values)
    tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
    # TODO: plot model stats
    # if c.tb_model_param_stats:
    # tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step


@torch.no_grad()
def evaluate(model, criterion, ap, global_step, epoch):
    """Compute the validation loss over ``eval_data``.

    Every ``c.test_every_epochs`` epochs (except epoch 0) it also vocodes a
    random validation utterance and logs audio and spectrograms. Stats are
    logged to TensorBoard at ``global_step``, the training step at which
    evaluation was run.

    Args:
        model: the WaveRNN model (switched to eval mode here).
        criterion: loss callable ``criterion(y_hat, y)``.
        ap: audio processor.
        global_step: current training step, used as the TensorBoard x-axis.
        epoch: current epoch index.

    Returns:
        dict: running averages (``avg_model_loss``, ``avg_loader_time``,
        ``avg_step_time``), or an empty dict when evaluation is disabled
        (``setup_loader`` returned ``None``).
    """
    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
    if data_loader is None:
        # evaluation disabled via c.run_eval
        return {}
    model.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    end_time = time.time()
    c_logger.print_eval_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()
        x_input, mels, y_coarse = format_data(data)
        loader_time = time.time() - end_time

        y_hat = model(x_input, mels)
        if isinstance(model.mode, int):
            y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
        else:
            y_coarse = y_coarse.float()
        y_coarse = y_coarse.unsqueeze(-1)
        loss = criterion(y_hat, y_coarse)
        loss_dict = {"model_loss": loss.item()}

        step_time = time.time() - start_time
        epoch_time += step_time

        # update avg stats
        update_eval_values = {"avg_" + key: value for key, value in loss_dict.items()}
        update_eval_values["avg_loader_time"] = loader_time
        update_eval_values["avg_step_time"] = step_time
        keep_avg.update_values(update_eval_values)

        if c.print_eval:
            c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
        # reset the reference point so loader_time measures only the next fetch
        end_time = time.time()

    # eval_data may be empty; random.randrange(0, 0) would raise
    if epoch % c.test_every_epochs == 0 and epoch != 0 and eval_data:
        rand_idx = random.randrange(0, len(eval_data))
        item = eval_data[rand_idx]
        # items are either plain wav paths or (wav_path, ...) tuples/lists
        wav_path = item[0] if isinstance(item, (tuple, list)) else item
        wav = ap.load_wav(wav_path)
        ground_mel = ap.melspectrogram(wav)
        sample_wav = model.generate(ground_mel,
                                    c.batched,
                                    c.target_samples,
                                    c.overlap_samples,
                                    use_cuda
                                    )
        predict_mel = ap.melspectrogram(sample_wav)

        tb_logger.tb_eval_audios(
            global_step, {"eval/audio": sample_wav}, c.audio["sample_rate"]
        )

        figures = {"eval/ground_truth": plot_spectrogram(ground_mel.T),
                   "eval/prediction": plot_spectrogram(predict_mel.T)
                   }
        tb_logger.tb_eval_figures(global_step, figures)

    tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
    return keep_avg.avg_values


def _setup_wavernn_criterion(mode):
    """Return the loss callable for WaveRNN output ``mode``.

    ``"mold"`` -> ``discretized_mix_logistic_loss``, ``"gauss"`` ->
    ``gaussian_loss``, any int -> ``torch.nn.CrossEntropyLoss()``.

    Raises:
        ValueError: for any other ``mode``, instead of leaving the criterion
            unbound and failing later.
    """
    if mode == "mold":
        return discretized_mix_logistic_loss
    if mode == "gauss":
        return gaussian_loss
    if isinstance(mode, int):
        return torch.nn.CrossEntropyLoss()
    raise ValueError(f" [!] Unknown WaveRNN mode: {mode!r}")


def _restore_training_state(restore_path, model, optimizer, scheduler, scaler):
    """Restore model/optimizer/scheduler/AMP-scaler state from ``restore_path``.

    If loading the full state raises ``RuntimeError``, only the matching
    model layers are loaded (via ``set_init_dict``).

    Returns:
        int: the checkpoint's ``step``.
    """
    checkpoint = torch.load(restore_path, map_location="cpu")
    try:
        print(" > Restoring Model...")
        model.load_state_dict(checkpoint["model"])
        print(" > Restoring Optimizer...")
        optimizer.load_state_dict(checkpoint["optimizer"])
        # NOTE(review): the saved entry may be None when the saving run had no
        # scheduler / no AMP, and this run may not configure one; skip either way.
        if scheduler is not None and checkpoint.get("scheduler") is not None:
            print(" > Restoring Generator LR Scheduler...")
            scheduler.load_state_dict(checkpoint["scheduler"])
            scheduler.optimizer = optimizer
        if c.mixed_precision and checkpoint.get("scaler") is not None:
            print(" > Restoring AMP Scaler...")
            scaler.load_state_dict(checkpoint["scaler"])
    except RuntimeError:
        # restore only matching layers
        print(" > Partial model initialization...")
        model_dict = model.state_dict()
        model_dict = set_init_dict(model_dict, checkpoint["model"], c)
        model.load_state_dict(model_dict)

    print(" > Model restored from step %d" % checkpoint["step"], flush=True)
    return checkpoint["step"]


# FIXME: move args definition/parsing inside of main?
def main(args):  # pylint: disable=redefined-outer-name
    """Train the WaveRNN vocoder described by the module-level config ``c``.

    The function:
    - loads the wav (and optional feature) split into ``train_data`` / ``eval_data``;
    - builds the model, criterion, RAdam optimizer and optional LR scheduler;
    - restores ``args.restore_path`` if given;
    - alternates ``train`` and (when ``c.run_eval``) ``evaluate`` for
      ``c.epochs`` epochs, passing each epoch's loss to ``save_best_model``.

    Sets ``args.restore_step``.
    """
    # pylint: disable=global-variable-undefined
    global train_data, eval_data

    ap = AudioProcessor(**c.audio)

    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(
            c.data_path, c.feature_path, c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(
            c.data_path, c.eval_split_size)

    model_wavernn = setup_wavernn(c)
    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
    criterion = _setup_wavernn_criterion(c.mode)

    if use_cuda:
        model_wavernn.cuda()
        # only the CrossEntropyLoss criterion is a module that can be moved
        if isinstance(c.mode, int):
            criterion.cuda()

    optimizer = RAdam(model_wavernn.parameters(), lr=c.lr, weight_decay=0)

    scheduler = None
    if "lr_scheduler" in c:
        scheduler_cls = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
        scheduler = scheduler_cls(optimizer, **c.lr_scheduler_params)

    if args.restore_path:
        args.restore_step = _restore_training_state(
            args.restore_path, model_wavernn, optimizer, scheduler, scaler)
    else:
        args.restore_step = 0

    num_parameters = count_parameters(model_wavernn)
    print(" > Model has {} parameters".format(num_parameters), flush=True)

    best_loss = float("inf")
    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        train_avg_loss_dict, global_step = train(model_wavernn, optimizer,
                                                 criterion, scheduler, scaler,
                                                 ap, global_step, epoch)
        if c.run_eval:
            epoch_loss_dict = evaluate(
                model_wavernn, criterion, ap, global_step, epoch)
        else:
            # no validation loader is built; select the best model on training loss
            epoch_loss_dict = train_avg_loss_dict
        c_logger.print_epoch_end(epoch, epoch_loss_dict)
        target_loss = epoch_loss_dict["avg_model_loss"]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model_wavernn,
            optimizer,
            scheduler,
            None,
            None,
            None,
            global_step,
            epoch,
            OUT_PATH,
            model_losses=epoch_loss_dict,
            scaler=scaler.state_dict() if c.mixed_precision else None
        )


if __name__ == "__main__":

    def _flag_given(flag):
        """Return True if ``flag`` is on the command line as ``flag`` or ``flag=value``."""
        return any(arg == flag or arg.startswith(flag + "=") for arg in sys.argv[1:])

    def _str2bool(value):
        """Parse a command-line boolean such as "true"/"false"/"1"/"0".

        ``type=bool`` would turn every non-empty string, including "False",
        into True.
        """
        lowered = str(value).strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--continue_path",
        type=str,
        help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
        default="",
        required=not _flag_given("--config_path"),
    )
    parser.add_argument(
        "--restore_path",
        type=str,
        help="Model file to be restored. Use to finetune a model.",
        default="",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        help="Path to config file for training.",
        required=not _flag_given("--continue_path"),
    )
    parser.add_argument(
        "--debug",
        type=_str2bool,
        default=False,
        help="Do not verify commit integrity to run training.",
    )

    # DISTRIBUTED
    parser.add_argument(
        "--rank",
        type=int,
        default=0,
        help="DISTRIBUTED: process rank for distributed training.",
    )
    parser.add_argument(
        "--group_id", type=str, default="", help="DISTRIBUTED: process group id."
    )
    args = parser.parse_args()

    if args.continue_path != "":
        args.output_path = args.continue_path
        args.config_path = os.path.join(args.continue_path, "config.json")
        # escape the folder so glob metacharacters in its name are matched literally
        list_of_files = glob.glob(
            os.path.join(glob.escape(args.continue_path), "*.pth.tar"))
        if not list_of_files:
            parser.error(
                f"no checkpoint (*.pth.tar) found in --continue_path {args.continue_path!r}")
        # resume from the most recently created checkpoint
        latest_model_file = max(list_of_files, key=os.path.getctime)
        args.restore_path = latest_model_file
        print(f" > Training continues for {args.restore_path}")

    # setup output paths and read configs
    c = load_config(args.config_path)

    OUT_PATH = args.continue_path
    if args.continue_path == "":
        OUT_PATH = create_experiment_folder(
            c.output_path, c.run_name, args.debug
        )

    AUDIO_PATH = os.path.join(OUT_PATH, "test_audios")

    c_logger = ConsoleLogger()

    # NOTE(review): tb_logger is only created on rank 0, but train() and
    # evaluate() use it unconditionally; multi-process runs are not wired up.
    if args.rank == 0:
        os.makedirs(AUDIO_PATH, exist_ok=True)
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        copy_config_file(
            args.config_path, os.path.join(OUT_PATH, "config.json"), new_fields
        )
        os.chmod(AUDIO_PATH, 0o775)
        os.chmod(OUT_PATH, 0o775)

        LOG_DIR = OUT_PATH
        tb_logger = TensorboardLogger(LOG_DIR, model_name="VOCODER")

        # write model desc to tensorboard
        tb_logger.tb_add_text("model-description", c["run_description"], 0)

    try:
        main(args)
    except KeyboardInterrupt:
        remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            # hard exit; presumably so lingering worker processes cannot block shutdown
            os._exit(0)  # pylint: disable=protected-access
    except Exception:  # pylint: disable=broad-except
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)