import argparse
import glob
import os
import sys
import time
import traceback

import torch
from torch.utils.data import DataLoader

from inspect import signature

from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
                                     create_experiment_folder, get_git_branch,
                                     remove_experiment_folder, set_init_dict)
from TTS.utils.io import copy_config_file, load_config
from TTS.utils.radam import RAdam
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.training import setup_torch_training_env
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
# from distribute import (DistributedSampler, apply_gradient_allreduce,
#                         init_distributed, reduce_tensor)
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
from TTS.vocoder.utils.io import save_checkpoint, save_best_model
from TTS.vocoder.utils.console_logger import ConsoleLogger
from TTS.vocoder.utils.generic_utils import (check_config, plot_results,
                                             setup_discriminator,
                                             setup_generator)


# Configure the torch training environment once, at import time. The returned
# flags are module-level globals: `use_cuda` is read by format_data(), train()
# and main(); `num_gpus` is read by train().
# NOTE(review): the meaning of the two positional `True` arguments is defined
# in TTS.utils.training.setup_torch_training_env - confirm there before
# changing them.
use_cuda, num_gpus = setup_torch_training_env(True, True)


def setup_loader(ap, is_val=False, verbose=False):
    """Create the ``DataLoader`` for the training or the validation split.

    Reads the module-level config ``c`` and the module-level item lists
    ``train_data`` / ``eval_data``.

    Args:
        ap: audio processor; handed to ``GANDataset`` and queried for
            ``hop_length``.
        is_val (bool): build the validation loader instead of the training
            loader.
        verbose (bool): forwarded to ``GANDataset``.

    Returns:
        The data loader, or ``None`` when a validation loader is requested
        while ``c.run_eval`` is falsy.
    """
    # Evaluation is switched off in the config: there is nothing to load.
    if is_val and not c.run_eval:
        return None

    training = not is_val
    gan_dataset = GANDataset(ap=ap,
                             items=train_data if training else eval_data,
                             seq_len=c.seq_len,
                             hop_len=ap.hop_length,
                             pad_short=c.pad_short,
                             conv_pad=c.conv_pad,
                             is_training=training,
                             return_segments=training,
                             use_noise_augment=c.use_noise_augment,
                             use_cache=c.use_cache,
                             verbose=verbose)
    gan_dataset.shuffle_mapping()

    # Validation goes through one item per batch and has its own worker count.
    if is_val:
        samples_per_batch = 1
        worker_count = c.num_val_loader_workers
    else:
        samples_per_batch = c.batch_size
        worker_count = c.num_loader_workers

    # No DistributedSampler is used (yet), hence sampler=None.
    return DataLoader(gan_dataset,
                      batch_size=samples_per_batch,
                      shuffle=True,
                      drop_last=False,
                      sampler=None,
                      num_workers=worker_count,
                      pin_memory=False)


def format_data(data):
    """Unpack a batch and move its tensors to the GPU when CUDA is in use.

    Two batch layouts are handled:

    * ``data[0]`` is a list: ``data[0]`` holds the generator pair and
      ``data[1]`` the discriminator pair.
    * otherwise ``data`` itself is a single pair (a whole audio segment).

    Args:
        data: batch as produced by the data loader.

    Returns:
        tuple: ``(c_G, x_G, c_D, x_D)``; the last two entries are ``None``
        for the single-pair layout.
    """

    def _to_device(tensor):
        # The copy is only issued when the module-level `use_cuda` flag is set.
        return tensor.cuda(non_blocking=True) if use_cuda else tensor

    if not isinstance(data[0], list):
        # a whole audio segment: one (features, waveform) pair
        cond, wav = data
        return _to_device(cond), _to_device(wav), None, None

    # separate inputs for the generator and the discriminator
    gen_cond, gen_wav = data[0]
    disc_cond, disc_wav = data[1]
    return tuple(
        _to_device(item)
        for item in (gen_cond, gen_wav, disc_cond, disc_wav))


def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
          scheduler_G, scheduler_D, ap, global_step, epoch):
    """Train the generator and the discriminator for one epoch.

    For every batch the generator is updated first. The discriminator
    contributes to the generator loss only once ``global_step`` is greater
    than ``c.steps_to_start_discriminator``; the discriminator itself is
    updated once ``global_step`` is greater than or equal to that value.
    Step statistics are printed every ``c.print_step`` steps, sent to
    tensorboard every 10 steps, and every ``c.save_step`` steps a checkpoint
    (if ``c.checkpoint`` is set), spectrogram figures and an audio sample are
    written.

    Relies on the module-level globals ``c``, ``c_logger``, ``tb_logger``,
    ``OUT_PATH``, ``use_cuda`` and ``num_gpus``.

    Args:
        model_G: generator. Must provide ``pqmf_synthesis`` and
            ``pqmf_analysis`` when its output has more than one channel.
        criterion_G: generator loss; returns a dict containing ``'G_loss'``.
        optimizer_G: generator optimizer.
        model_D: discriminator. Receives the conditioning features as well
            when its ``forward`` takes exactly two parameters.
        criterion_D: discriminator loss; returns a dict containing
            ``'D_loss'``.
        optimizer_D: discriminator optimizer.
        scheduler_G: generator LR scheduler or ``None``.
        scheduler_D: discriminator LR scheduler or ``None``.
        ap: audio processor used for loading data and plotting.
        global_step (int): number of training steps done so far.
        epoch (int): index of the current epoch.

    Returns:
        tuple: ``(keep_avg.avg_values, global_step)`` - the running averages
        of this epoch and the updated step counter.
    """
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    model_G.train()
    model_D.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    # The signature of D does not change while training: inspect it once
    # instead of on every step.
    d_takes_cond = len(signature(model_D.forward).parameters) == 2
    end_time = time.time()
    c_logger.print_train_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        c_G, y_G, c_D, y_D = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        ##############################
        # GENERATOR
        ##############################

        # generator pass
        y_hat = model_G(c_G)
        y_hat_sub = None
        y_G_sub = None
        y_hat_vis = y_hat  # for visualization

        # PQMF formatting
        if y_hat.shape[1] > 1:
            y_hat_sub = y_hat
            y_hat = model_G.pqmf_synthesis(y_hat)
            y_hat_vis = y_hat
            y_G_sub = model_G.pqmf_analysis(y_G)

        # Reset the discriminator outputs on every step so that a
        # discriminator returning a plain tensor (no features) never leaves
        # `feats_fake` / `feats_real` undefined or stale.
        scores_fake, feats_fake, feats_real = None, None, None
        if global_step > c.steps_to_start_discriminator:

            # run D with or without cond. features
            if d_takes_cond:
                D_out_fake = model_D(y_hat, c_G)
            else:
                D_out_fake = model_D(y_hat)
            D_out_real = None

            # real features are only needed for the feature matching loss
            if c.use_feat_match_loss:
                with torch.no_grad():
                    D_out_real = model_D(y_G)

            # format D outputs
            if isinstance(D_out_fake, tuple):
                scores_fake, feats_fake = D_out_fake
                if D_out_real is not None:
                    _, feats_real = D_out_real
            else:
                scores_fake = D_out_fake

        # compute losses
        loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
                                  feats_real, y_hat_sub, y_G_sub)
        loss_G = loss_G_dict['G_loss']

        # optimizer generator
        optimizer_G.zero_grad()
        loss_G.backward()
        if c.gen_clip_grad > 0:
            torch.nn.utils.clip_grad_norm_(model_G.parameters(),
                                           c.gen_clip_grad)
        optimizer_G.step()
        if scheduler_G is not None:
            scheduler_G.step()

        # Plain Python numbers are stored as they are, everything else is
        # converted with `.item()`.
        loss_dict = {}
        for key, value in loss_G_dict.items():
            if isinstance(value, (int, float)):
                loss_dict[key] = value
            else:
                loss_dict[key] = value.item()

        ##############################
        # DISCRIMINATOR
        ##############################
        if global_step >= c.steps_to_start_discriminator:
            # discriminator pass: no gradients flow back into the generator
            with torch.no_grad():
                y_hat = model_G(c_D)

            # PQMF formatting
            if y_hat.shape[1] > 1:
                y_hat = model_G.pqmf_synthesis(y_hat)

            # run D with or without cond. features
            if d_takes_cond:
                D_out_fake = model_D(y_hat.detach(), c_D)
                D_out_real = model_D(y_D, c_D)
            else:
                D_out_fake = model_D(y_hat.detach())
                D_out_real = model_D(y_D)

            # format D outputs; only the scores are needed for the D loss
            if isinstance(D_out_fake, tuple):
                scores_fake, _ = D_out_fake
                scores_real, _ = D_out_real
            else:
                scores_fake = D_out_fake
                scores_real = D_out_real

            # compute losses
            loss_D_dict = criterion_D(scores_fake, scores_real)
            loss_D = loss_D_dict['D_loss']

            # optimizer discriminator
            optimizer_D.zero_grad()
            loss_D.backward()
            if c.disc_clip_grad > 0:
                torch.nn.utils.clip_grad_norm_(model_D.parameters(),
                                               c.disc_clip_grad)
            optimizer_D.step()
            if scheduler_D is not None:
                scheduler_D.step()

            for key, value in loss_D_dict.items():
                if isinstance(value, (int, float)):
                    loss_dict[key] = value
                else:
                    loss_dict[key] = value.item()

        step_time = time.time() - start_time
        epoch_time += step_time

        # get current learning rates
        current_lr_G = list(optimizer_G.param_groups)[0]['lr']
        current_lr_D = list(optimizer_D.param_groups)[0]['lr']

        # update avg stats
        update_train_values = {
            'avg_' + key: value
            for key, value in loss_dict.items()
        }
        update_train_values['avg_loader_time'] = loader_time
        update_train_values['avg_step_time'] = step_time
        keep_avg.update_values(update_train_values)

        # print training stats
        if global_step % c.print_step == 0:
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      step_time, loader_time, current_lr_G,
                                      current_lr_D, loss_dict,
                                      keep_avg.avg_values)

        # plot step stats
        if global_step % 10 == 0:
            iter_stats = {
                "lr_G": current_lr_G,
                "lr_D": current_lr_D,
                "step_time": step_time
            }
            iter_stats.update(loss_dict)
            tb_logger.tb_train_iter_stats(global_step, iter_stats)

        # save checkpoint
        if global_step % c.save_step == 0:
            if c.checkpoint:
                # save model
                save_checkpoint(model_G,
                                optimizer_G,
                                scheduler_G,
                                model_D,
                                optimizer_D,
                                scheduler_D,
                                global_step,
                                epoch,
                                OUT_PATH,
                                model_losses=loss_dict)

            # compute spectrograms
            figures = plot_results(y_hat_vis, y_G, ap, global_step,
                                   'train')
            tb_logger.tb_train_figures(global_step, figures)

            # Sample audio
            sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy()
            tb_logger.tb_train_audios(global_step,
                                      {'train/audio': sample_voice},
                                      c.audio["sample_rate"])
        # restart the clock that measures how long the next batch takes to load
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Training Epoch Stats
    epoch_stats = {"epoch_time": epoch_time}
    epoch_stats.update(keep_avg.avg_values)
    tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
    # TODO: plot model stats
    # if c.tb_model_param_stats:
    # tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step


@torch.no_grad()
def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch):
    """Compute generator and discriminator losses on the validation split.

    No parameters are updated; the whole function runs under
    ``torch.no_grad()``. After the last batch, spectrogram figures, an audio
    sample and the averaged statistics are sent to tensorboard.

    Relies on the module-level globals ``c``, ``c_logger`` and ``tb_logger``.

    Args:
        model_G: generator. Must provide ``pqmf_synthesis`` and
            ``pqmf_analysis`` when its output has more than one channel.
        criterion_G: generator loss; returns a dict of loss values.
        model_D: discriminator. Receives the conditioning features as well
            when its ``forward`` takes exactly two parameters.
        criterion_D: discriminator loss; returns a dict of loss values.
        ap: audio processor used for loading data and plotting.
        global_step (int): training step at which the evaluation starts.
        epoch (int): index of the current epoch.

    Returns:
        ``keep_avg.avg_values`` - the running averages over all evaluation
        batches, keyed as ``'avg_' + <loss name>``.
    """
    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
    model_G.eval()
    model_D.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    # The signature of D is fixed: inspect it once instead of per batch.
    d_takes_cond = len(signature(model_D.forward).parameters) == 2
    end_time = time.time()
    c_logger.print_eval_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        c_G, y_G, _, _ = format_data(data)
        loader_time = time.time() - end_time

        # NOTE(review): the step counter is advanced per evaluation batch.
        # This only affects the local copy; the caller's counter is untouched.
        global_step += 1

        ##############################
        # GENERATOR
        ##############################

        # generator pass
        y_hat = model_G(c_G)
        y_hat_sub = None
        y_G_sub = None

        # PQMF formatting
        if y_hat.shape[1] > 1:
            y_hat_sub = y_hat
            y_hat = model_G.pqmf_synthesis(y_hat)
            y_G_sub = model_G.pqmf_analysis(y_G)

        # Reset the discriminator outputs on every batch so that a
        # discriminator returning a plain tensor (no features) never leaves
        # `feats_fake` / `feats_real` undefined or stale.
        scores_fake, feats_fake, feats_real = None, None, None
        if global_step > c.steps_to_start_discriminator:

            # run D with or without cond. features
            if d_takes_cond:
                D_out_fake = model_D(y_hat, c_G)
            else:
                D_out_fake = model_D(y_hat)
            D_out_real = None

            # real features are only needed for the feature matching loss
            if c.use_feat_match_loss:
                D_out_real = model_D(y_G)

            # format D outputs
            if isinstance(D_out_fake, tuple):
                scores_fake, feats_fake = D_out_fake
                if D_out_real is not None:
                    _, feats_real = D_out_real
            else:
                scores_fake = D_out_fake

        # compute losses
        loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
                                  feats_real, y_hat_sub, y_G_sub)

        # Plain Python numbers are stored as they are, everything else is
        # converted with `.item()`.
        loss_dict = {}
        for key, value in loss_G_dict.items():
            if isinstance(value, (int, float)):
                loss_dict[key] = value
            else:
                loss_dict[key] = value.item()

        ##############################
        # DISCRIMINATOR
        ##############################

        if global_step >= c.steps_to_start_discriminator:
            # discriminator pass
            y_hat = model_G(c_G)

            # PQMF formatting
            if y_hat.shape[1] > 1:
                y_hat = model_G.pqmf_synthesis(y_hat)

            # run D with or without cond. features
            if d_takes_cond:
                D_out_fake = model_D(y_hat.detach(), c_G)
                D_out_real = model_D(y_G, c_G)
            else:
                D_out_fake = model_D(y_hat.detach())
                D_out_real = model_D(y_G)

            # format D outputs; only the scores are needed for the D loss
            if isinstance(D_out_fake, tuple):
                scores_fake, _ = D_out_fake
                scores_real, _ = D_out_real
            else:
                scores_fake = D_out_fake
                scores_real = D_out_real

            # compute losses
            loss_D_dict = criterion_D(scores_fake, scores_real)

            for key, value in loss_D_dict.items():
                if isinstance(value, (int, float)):
                    loss_dict[key] = value
                else:
                    loss_dict[key] = value.item()

        step_time = time.time() - start_time
        epoch_time += step_time

        # update avg stats
        update_eval_values = {
            'avg_' + key: value
            for key, value in loss_dict.items()
        }
        update_eval_values['avg_loader_time'] = loader_time
        update_eval_values['avg_step_time'] = step_time
        keep_avg.update_values(update_eval_values)

        # print eval stats
        if c.print_eval:
            c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

        # Restart the clock so that `loader_time` measures the fetch of the
        # next batch only, not the time elapsed since the evaluation began.
        end_time = time.time()

    # compute spectrograms (uses the outputs of the last batch)
    figures = plot_results(y_hat, y_G, ap, global_step, 'eval')
    tb_logger.tb_eval_figures(global_step, figures)

    # Sample audio
    sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
    tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
                             c.audio["sample_rate"])

    tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)

    return keep_avg.avg_values


# FIXME: move args definition/parsing inside of main?
def main(args):  # pylint: disable=redefined-outer-name
    """Build models, optimizers and losses, then run the training loop.

    Loads the data item lists into the module-level globals ``train_data``
    and ``eval_data``, optionally restores a checkpoint from
    ``args.restore_path``, and alternates ``train`` and ``evaluate`` for
    ``c.epochs`` epochs. After each epoch ``save_best_model`` is called with
    the evaluation loss selected by ``c.target_loss``.

    Relies on the module-level globals ``c``, ``c_logger``, ``OUT_PATH`` and
    ``use_cuda``.

    Args:
        args: parsed command line arguments. ``args.restore_path`` is read
            and ``args.restore_step`` is set.
    """
    # pylint: disable=global-variable-undefined
    global train_data, eval_data
    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**c.audio)

    # DISTRIBUTED
    # if num_gpus > 1:
    # init_distributed(args.rank, num_gpus, args.group_id,
    #  c.distributed["backend"], c.distributed["url"])

    # setup models
    model_gen = setup_generator(c)
    model_disc = setup_discriminator(c)

    # setup optimizers
    optimizer_gen = RAdam(model_gen.parameters(), lr=c.lr_gen, weight_decay=0)
    optimizer_disc = RAdam(model_disc.parameters(),
                           lr=c.lr_disc,
                           weight_decay=0)

    # schedulers are optional: they exist only if named in the config
    scheduler_gen = None
    scheduler_disc = None
    if 'lr_scheduler_gen' in c:
        scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
        scheduler_gen = scheduler_gen(optimizer_gen, **c.lr_scheduler_gen_params)
    if 'lr_scheduler_disc' in c:
        scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
        scheduler_disc = scheduler_disc(optimizer_disc, **c.lr_scheduler_disc_params)

    # setup criterion
    criterion_gen = GeneratorLoss(c)
    criterion_disc = DiscriminatorLoss(c)

    if args.restore_path:
        # NOTE: torch.load unpickles the file - only restore checkpoints from
        # a trusted source.
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            print(" > Restoring Generator Model...")
            model_gen.load_state_dict(checkpoint['model'])
            print(" > Restoring Generator Optimizer...")
            optimizer_gen.load_state_dict(checkpoint['optimizer'])
            print(" > Restoring Discriminator Model...")
            model_disc.load_state_dict(checkpoint['model_disc'])
            print(" > Restoring Discriminator Optimizer...")
            optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
            # A scheduler state can only be restored when a scheduler is
            # configured for this run and the checkpoint holds a state for it.
            if (scheduler_gen is not None and 'scheduler' in checkpoint
                    and checkpoint['scheduler'] is not None):
                print(" > Restoring Generator LR Scheduler...")
                scheduler_gen.load_state_dict(checkpoint['scheduler'])
                # NOTE: Not sure if necessary
                scheduler_gen.optimizer = optimizer_gen
            if (scheduler_disc is not None and 'scheduler_disc' in checkpoint
                    and checkpoint['scheduler_disc'] is not None):
                print(" > Restoring Discriminator LR Scheduler...")
                scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
                scheduler_disc.optimizer = optimizer_disc
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model_gen.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            model_gen.load_state_dict(model_dict)

            model_dict = model_disc.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model_disc'], c)
            model_disc.load_state_dict(model_dict)
            del model_dict

        # Overwrite the restored learning rates with the configured ones.
        # NOTE(review): this is done unconditionally, i.e. also when a
        # training is continued - confirm this is intended.
        for group in optimizer_gen.param_groups:
            group['lr'] = c.lr_gen

        for group in optimizer_disc.param_groups:
            group['lr'] = c.lr_disc

        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model_gen.cuda()
        criterion_gen.cuda()
        model_disc.cuda()
        criterion_disc.cuda()

    # DISTRIBUTED
    # if num_gpus > 1:
    #     model = apply_gradient_allreduce(model)

    num_params = count_parameters(model_gen)
    print(" > Generator has {} parameters".format(num_params), flush=True)
    num_params = count_parameters(model_disc)
    print(" > Discriminator has {} parameters".format(num_params), flush=True)

    # no best loss is known at start, so the first evaluation always improves
    best_loss = float('inf')

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(model_gen, criterion_gen, optimizer_gen,
                               model_disc, criterion_disc, optimizer_disc,
                               scheduler_gen, scheduler_disc, ap, global_step,
                               epoch)
        eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc, criterion_disc, ap,
                                      global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict[c.target_loss]
        best_loss = save_best_model(target_loss,
                                    best_loss,
                                    model_gen,
                                    optimizer_gen,
                                    scheduler_gen,
                                    model_disc,
                                    optimizer_disc,
                                    scheduler_disc,
                                    global_step,
                                    epoch,
                                    OUT_PATH,
                                    model_losses=eval_avg_loss_dict)


def _str_to_bool(value):
    """Parse a command line string into a bool.

    ``type=bool`` must not be used with argparse: every non-empty string,
    including ``'False'``, is truthy.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised
            true / false spelling.
    """
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Exactly one of --continue_path / --config_path has to be given: each is
    # required only when the other is absent from the command line.
    parser.add_argument(
        '--continue_path',
        type=str,
        help=
        'Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
        default='',
        required='--config_path' not in sys.argv)
    parser.add_argument(
        '--restore_path',
        type=str,
        help='Model file to be restored. Use to finetune a model.',
        default='')
    parser.add_argument('--config_path',
                        type=str,
                        help='Path to config file for training.',
                        required='--continue_path' not in sys.argv)
    parser.add_argument('--debug',
                        type=_str_to_bool,
                        default=False,
                        help='Do not verify commit integrity to run training.')

    # DISTRIBUTED
    parser.add_argument(
        '--rank',
        type=int,
        default=0,
        help='DISTRIBUTED: process rank for distributed training.')
    parser.add_argument('--group_id',
                        type=str,
                        default="",
                        help='DISTRIBUTED: process group id.')
    args = parser.parse_args()

    if args.continue_path != '':
        # Continue a previous run: reuse its config and its newest checkpoint.
        args.output_path = args.continue_path
        args.config_path = os.path.join(args.continue_path, 'config.json')
        list_of_files = glob.glob(
            os.path.join(args.continue_path, "*.pth.tar"))
        if not list_of_files:
            raise FileNotFoundError(
                f"No '*.pth.tar' checkpoint found in: {args.continue_path}")
        latest_model_file = max(list_of_files, key=os.path.getctime)
        args.restore_path = latest_model_file
        print(f" > Training continues for {args.restore_path}")

    # setup output paths and read configs
    c = load_config(args.config_path)
    check_config(c)

    OUT_PATH = args.continue_path
    if args.continue_path == '':
        OUT_PATH = create_experiment_folder(c.output_path, c.run_name,
                                            args.debug)

    AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')

    c_logger = ConsoleLogger()

    if args.rank == 0:
        os.makedirs(AUDIO_PATH, exist_ok=True)
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        copy_config_file(args.config_path,
                         os.path.join(OUT_PATH, 'config.json'), new_fields)
        os.chmod(AUDIO_PATH, 0o775)
        os.chmod(OUT_PATH, 0o775)

        LOG_DIR = OUT_PATH
        tb_logger = TensorboardLogger(LOG_DIR, model_name='VOCODER')

        # write model desc to tensorboard
        tb_logger.tb_add_text('model-description', c['run_description'], 0)

    try:
        main(args)
    except KeyboardInterrupt:
        remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)  # pylint: disable=protected-access
    except Exception:  # pylint: disable=broad-except
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)