#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Argument parser for training scripts."""

import argparse
import re
import glob
import os

from TTS.utils.generic_utils import (
    create_experiment_folder, get_git_branch)
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.io import copy_model_files, load_config
from TTS.utils.tensorboard_logger import TensorboardLogger

from TTS.tts.utils.generic_utils import check_config_tts


def parse_arguments(argv):
    """Parse command line arguments of training scripts.

    Parameters
    ----------
    argv : list
        This is a list of input arguments as given by ``sys.argv``. It is used
        to decide which of ``--continue_path`` / ``--config_path`` is
        required; both ``--flag value`` and ``--flag=value`` spellings are
        recognised. Note that the actual parsing is done by argparse on
        ``sys.argv`` itself.

    Returns
    -------
    argparse.Namespace
        Parsed arguments.

    """

    def _has_flag(flag):
        # True if ``flag`` appears either as a standalone token or in the
        # ``--flag=value`` form.
        return any(arg == flag or arg.startswith(flag + "=") for arg in argv)

    def _str2bool(value):
        # ``type=bool`` would turn every non-empty string (including "False")
        # into True, so parse common boolean spellings explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError(
            f"Boolean value expected, got {value!r}.")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--continue_path",
        type=str,
        help=("Training output folder to continue training. Used to continue "
              "a training. If it is used, 'config_path' is ignored."),
        default="",
        required=not _has_flag("--config_path"))
    parser.add_argument(
        "--restore_path",
        type=str,
        help="Model file to be restored. Use to finetune a model.",
        default="")
    parser.add_argument(
        "--best_path",
        type=str,
        help="Best model file to be used for extracting best loss.",
        default="")
    parser.add_argument(
        "--config_path",
        type=str,
        help="Path to config file for training.",
        required=not _has_flag("--continue_path"))
    parser.add_argument(
        "--debug",
        type=_str2bool,
        nargs="?",
        const=True,  # a bare ``--debug`` enables debug mode
        default=False,
        help="Do not verify commit integrity to run training.")
    parser.add_argument(
        "--rank",
        type=int,
        default=0,
        help="DISTRIBUTED: process rank for distributed training.")
    parser.add_argument(
        "--group_id",
        type=str,
        default="",
        help="DISTRIBUTED: process group id.")

    return parser.parse_args()


def get_last_models(path):
    """Get latest checkpoint or/and best model in path.

    It is based on globbing for `*.pth.tar` and the RegEx
    `(checkpoint|best_model)_([0-9]+)`, matched against each file's base
    name only (so numbers in directory names cannot affect the result).

    Parameters
    ----------
    path : str
        Directory containing the model files to be compared.

    Raises
    ------
    ValueError
        If no checkpoint or best_model files are found.

    Returns
    -------
    last_checkpoint : str
        File to resume from: the `checkpoint_*` file with the highest step
        number, replaced by the newest `best_model_*` file when that one has
        a higher step number or when no checkpoint file exists.
    best_model : str or None
        The `best_model_*` file with the highest step number, or None if
        there is none.

    """
    file_names = glob.glob(os.path.join(path, "*.pth.tar"))
    last_models = {}
    last_model_nums = {}
    for key in ("checkpoint", "best_model"):
        pattern = re.compile(rf"{key}_([0-9]+)")
        for file_name in file_names:
            match = pattern.search(os.path.basename(file_name))
            if match is None:  # file name does not belong to this key
                continue
            model_num = int(match.group(1))
            # Only keys that matched a file are recorded, so the "not found"
            # checks below are meaningful; step 0 is a valid step number.
            if key not in last_model_nums or model_num > last_model_nums[key]:
                last_model_nums[key] = model_num
                last_models[key] = file_name

    # check what models were found
    if not last_models:
        raise ValueError(f"No models found in continue path {path}!")
    if "checkpoint" not in last_models:  # no checkpoint, just best model
        last_models["checkpoint"] = last_models["best_model"]
    elif "best_model" not in last_models:  # no best model
        # this shouldn't happen, but let's handle it just in case
        last_models["best_model"] = None
    # finally check if last best model is more recent than checkpoint
    elif last_model_nums["best_model"] > last_model_nums["checkpoint"]:
        last_models["checkpoint"] = last_models["best_model"]

    return last_models["checkpoint"], last_models["best_model"]


def process_args(args, model_type):
    """Process parsed command line arguments.

    When ``args.continue_path`` is set, ``args`` is updated in place:
    ``output_path`` and ``config_path`` are pointed at the continue folder,
    ``restore_path`` is set to the latest model found there, and
    ``best_path`` is filled in if it was empty.

    Parameters
    ----------
    args : argparse.Namespace or dict like
        Parsed input arguments.
    model_type : str
        Model type used to check config parameters and setup the TensorBoard
        logger. One of:
            - tacotron
            - glow_tts
            - speedy_speech
            - gan
            - wavegrad
            - wavernn

    Raises
    ------
    ValueError
        If `model_type` is not one of implemented choices.

    Returns
    -------
    c : TTS.utils.io.AttrDict
        Config parameters.
    out_path : str
        Path to save models and logging.
    audio_path : str
        Path to save generated test audios.
    c_logger : TTS.utils.console_logger.ConsoleLogger
        Class that does logging to the console.
    tb_logger : TTS.utils.tensorboard.TensorboardLogger or None
        Class that does the TensorBoard logging. Only created when
        ``args.rank == 0``; None on every other rank.

    """
    tts_model_types = ("tacotron", "glow_tts", "speedy_speech")
    vocoder_model_types = ("gan", "wavegrad", "wavernn")
    # model types for which a mixed-precision notice is printed
    mixed_precision_model_types = ("tacotron", "wavegrad", "wavernn")

    # Validate the model type first, using exact membership: a substring test
    # against a space-joined string would accept e.g. "tts", "wave" or "".
    if model_type in tts_model_types:
        model_class = "TTS"
    elif model_type in vocoder_model_types:
        model_class = "VOCODER"
    else:
        raise ValueError(f"model type {model_type} not recognized!")

    if args.continue_path:
        args.output_path = args.continue_path
        args.config_path = os.path.join(args.continue_path, "config.json")
        args.restore_path, best_model = get_last_models(args.continue_path)
        if not args.best_path:
            args.best_path = best_model

    # setup output paths and read configs
    c = load_config(args.config_path)

    if model_class == "TTS":
        check_config_tts(c)
    else:
        print("Vocoder config checker not implemented, skipping ...")

    if model_type in mixed_precision_model_types and c.mixed_precision:
        print("   >  Mixed precision mode is ON")

    out_path = args.continue_path
    if not out_path:
        out_path = create_experiment_folder(c.output_path, c.run_name,
                                            args.debug)

    audio_path = os.path.join(out_path, "test_audios")

    c_logger = ConsoleLogger()
    # Defined up front so non-zero ranks return None instead of hitting an
    # UnboundLocalError at the return statement.
    tb_logger = None

    if args.rank == 0:
        # Only rank 0 touches the filesystem and creates the TensorBoard logger.
        os.makedirs(audio_path, exist_ok=True)
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        copy_model_files(c, args.config_path,
                         out_path, new_fields)
        os.chmod(audio_path, 0o775)
        os.chmod(out_path, 0o775)

        tb_logger = TensorboardLogger(out_path, model_name=model_class)

        # write model desc to tensorboard
        tb_logger.tb_add_text("model-description", c["run_description"], 0)

    return c, out_path, audio_path, c_logger, tb_logger