# -*- coding: utf-8 -*-

import importlib
import multiprocessing
import os
import platform
import re
import sys
import time
import traceback
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union
from urllib.parse import urlparse

import fsspec
import torch
import torch.distributed as dist
from coqpit import Coqpit
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader

from TTS.config import load_config, register_config
from TTS.tts.datasets import load_meta_data
from TTS.tts.models import setup_model as setup_tts_model
from TTS.tts.utils.text.symbols import parse_symbols
from TTS.utils.audio import AudioProcessor
from TTS.utils.callbacks import TrainerCallback
from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import (
    KeepAverage,
    count_parameters,
    get_experiment_folder_path,
    get_git_branch,
    remove_experiment_folder,
    set_init_dict,
    to_cuda,
)
from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_dashboard_logger
from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.models import setup_model as setup_vocoder_model

multiprocessing.set_start_method("fork")

if platform.system() != "Windows":
    # https://github.com/pytorch/pytorch/issues/973
    import resource

    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))


if is_apex_available():
    from apex import amp


@dataclass
class TrainingArgs(Coqpit):
    """Trainer arguments to be defined externally. It helps integrating the `Trainer` with the higher level APIs and
    set the values for distributed training."""

    continue_path: str = field(
        default="",
        metadata={
            "help": "Path to a training folder to continue training. Restore the model from the last checkpoint and continue training under the same folder."
        },
    )
    restore_path: str = field(
        default="",
        metadata={
            "help": "Path to a model checkpoit. Restore the model with the given checkpoint and start a new training."
        },
    )
    best_path: str = field(
        default="",
        metadata={
            "help": "Best model file to be used for extracting the best loss. If not specified, the latest best model in continue path is used"
        },
    )
    config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
    rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
    group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})
    use_ddp: bool = field(
        default=False,
        metadata={"help": "Use DDP in distributed training. It is to set in `distribute.py`. Do not set manually."},
    )


class Trainer:
    def __init__(
        self,
        args: Union[Coqpit, Namespace],
        config: Coqpit,
        output_path: str,
        c_logger: ConsoleLogger = None,
        dashboard_logger: Union[TensorboardLogger, WandbLogger] = None,
        model: nn.Module = None,
        cudnn_benchmark: bool = False,
    ) -> None:
        """Simple yet powerful 🐸💬 TTS trainer for PyTorch. It can train all the available `tts` and `vocoder` models
        or easily be customized.

        Notes:

            Supports Automatic Mixed Precision training. If `Apex` is availabe, it automatically picks that, else
            it uses PyTorch's native `amp` module. `Apex` may provide more stable training in some cases.

        Args:

            args (Union[Coqpit, Namespace]): Training arguments parsed either from console by `argparse` or `TrainingArgs`
                config object.

            config (Coqpit): Model config object. It includes all the values necessary for initializing, training, evaluating
                and testing the model.

            output_path (str): Path to the output training folder. All the files are saved under thi path.

            c_logger (ConsoleLogger, optional): Console logger for printing training status. If not provided, the default
                console logger is used. Defaults to None.

            dashboard_logger Union[TensorboardLogger, WandbLogger]: Dashboard logger. If not provided, the tensorboard logger is used.
                Defaults to None.

            model (nn.Module, optional): Initialized and ready-to-train model. If it is not defined, `Trainer`
                initializes a model from the provided config. Defaults to None.

            cudnn_benchmark (bool): enable/disable PyTorch cudnn benchmarking. It is better to disable if the model input
                length is changing batch to batch along the training.

        Examples:

            Running trainer on a model.

            >>> args = TrainingArgs(...)
            >>> config = HifiganConfig(...)
            >>> model = GANModel(config)
            >>> trainer = Trainer(args, config, output_path, model=model)
            >>> trainer.fit()

            Running trainer on a config.

            >>> config = WavegradConfig(data_path="/home/erogol/nvme/gdrive/Datasets/LJSpeech-1.1/wavs/", output_path=output_path,)
            >>> args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
            >>> trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
            >>> trainer.fit()

        TODO:
            - Wrap model for not calling .module in DDP.
            - Accumulate gradients b/w batches.
            - Deepspeed integration
            - Profiler integration.
            - Overfitting to a batch.
            - TPU training
        """

        if config is None:
            # parse config from console arguments
            config, output_path, _, c_logger, dashboard_logger = process_args(args)

        self.args = args
        self.config = config
        self.output_path = output_path
        self.config.output_log_path = output_path

        # setup logging
        log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
        self._setup_logger_config(log_file)

        # set and initialize Pytorch runtime
        self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark, args.use_ddp)

        # init loggers
        self.c_logger = ConsoleLogger() if c_logger is None else c_logger
        self.dashboard_logger = dashboard_logger

        # only allow dashboard logging for the main process in DDP mode
        if self.dashboard_logger is None and args.rank == 0:
            self.dashboard_logger = init_dashboard_logger(config)

        if not self.config.log_model_step:
            self.config.log_model_step = self.config.save_step

        self.total_steps_done = 0
        self.epochs_done = 0
        self.restore_step = 0
        self.best_loss = float("inf")
        self.train_loader = None
        self.eval_loader = None

        self.keep_avg_train = None
        self.keep_avg_eval = None

        self.use_apex = self._is_apex_available()
        self.use_amp_scaler = self.config.mixed_precision and self.use_cuda

        # init audio processor
        self.ap = AudioProcessor(**self.config.audio.to_dict())

        # load data samples
        # TODO: refactor this
        if "datasets" in self.config:
            # load data for `tts` models
            self.data_train, self.data_eval = load_meta_data(self.config.datasets)
        elif self.config.feature_path is not None:
            # load pre-comnputed features for `vocoder`models
            print(f" > Loading features from: {self.config.feature_path}")
            self.data_eval, self.data_train = load_wav_feat_data(
                self.config.data_path, self.config.feature_path, self.config.eval_split_size
            )
        else:
            # load data for `vocoder`models
            self.data_eval, self.data_train = load_wav_data(self.config.data_path, self.config.eval_split_size)

        # init TTS model
        if model is not None:
            self.model = model
        else:
            self.model = self.get_model(self.config)

        # init multispeaker settings of the model
        if hasattr(self.model, "init_multispeaker"):
            self.model.init_multispeaker(self.config, self.data_train + self.data_eval)

        # setup criterion
        self.criterion = self.get_criterion(self.model)

        # DISTRUBUTED
        if self.num_gpus > 1:
            init_distributed(
                args.rank,
                self.num_gpus,
                args.group_id,
                self.config.distributed_backend,
                self.config.distributed_url,
            )

        if self.use_cuda:
            self.model.cuda()
            if isinstance(self.criterion, list):
                self.criterion = [x.cuda() for x in self.criterion]
            else:
                self.criterion.cuda()

        # setup optimizer
        self.optimizer = self.get_optimizer(self.model, self.config)

        # callback
        self.callbacks = TrainerCallback(self)
        self.callbacks.on_init_start()

        # init AMP
        if self.use_amp_scaler:
            if self.use_apex:
                self.scaler = None
                self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level="O1")
            # if isinstance(self.optimizer, list):
            #     self.scaler = [torch.cuda.amp.GradScaler()] * len(self.optimizer)
            # else:
            self.scaler = torch.cuda.amp.GradScaler()
        else:
            self.scaler = None

        if self.args.restore_path:
            self.model, self.optimizer, self.scaler, self.restore_step = self.restore_model(
                self.config, args.restore_path, self.model, self.optimizer, self.scaler
            )

        # setup scheduler
        self.scheduler = self.get_scheduler(self.model, self.config, self.optimizer)

        if self.scheduler is not None:
            if self.args.continue_path:
                if isinstance(self.scheduler, list):
                    for scheduler in self.scheduler:
                        if scheduler is not None:
                            scheduler.last_epoch = self.restore_step
                else:
                    self.scheduler.last_epoch = self.restore_step

        # DISTRUBUTED
        if self.num_gpus > 1:
            self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank)

        # count model size
        num_params = count_parameters(self.model)
        print("\n > Model has {} parameters".format(num_params))

        self.callbacks.on_init_end()

    @staticmethod
    def get_model(config: Coqpit) -> nn.Module:
        """Initialize model from config.

        Args:
            config (Coqpit): Model config.

        Returns:
            nn.Module: initialized model.
        """
        try:
            model = setup_vocoder_model(config)
        except ModuleNotFoundError:
            model = setup_tts_model(config)
        return model

    def restore_model(
        self,
        config: Coqpit,
        restore_path: str,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        scaler: torch.cuda.amp.GradScaler = None,
    ) -> Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]:
        """Restore training from an old run. It restores model, optimizer, AMP scaler and training stats.

        Args:
            config (Coqpit): Model config.
            restore_path (str): Path to the restored training run.
            model (nn.Module): Model to restored.
            optimizer (torch.optim.Optimizer): Optimizer to restore.
            scaler (torch.cuda.amp.GradScaler, optional): AMP scaler to restore. Defaults to None.

        Returns:
            Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]: [description]
        """

        def _restore_list_objs(states, obj):
            if isinstance(obj, list):
                for idx, state in enumerate(states):
                    obj[idx].load_state_dict(state)
            else:
                obj.load_state_dict(states)
            return obj

        print(" > Restoring from %s ..." % os.path.basename(restore_path))
        checkpoint = load_fsspec(restore_path, map_location="cpu")
        try:
            print(" > Restoring Model...")
            model.load_state_dict(checkpoint["model"])
            print(" > Restoring Optimizer...")
            optimizer = _restore_list_objs(checkpoint["optimizer"], optimizer)
            if "scaler" in checkpoint and self.use_amp_scaler and checkpoint["scaler"]:
                print(" > Restoring Scaler...")
                scaler = _restore_list_objs(checkpoint["scaler"], scaler)
        except (KeyError, RuntimeError):
            print(" > Partial model initialization...")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model"], config)
            model.load_state_dict(model_dict)
            del model_dict

        if isinstance(self.optimizer, list):
            for idx, optim in enumerate(optimizer):
                for group in optim.param_groups:
                    group["lr"] = self.get_lr(model, config)[idx]
        else:
            for group in optimizer.param_groups:
                group["lr"] = self.get_lr(model, config)
        print(
            " > Model restored from step %d" % checkpoint["step"],
        )
        restore_step = checkpoint["step"]
        torch.cuda.empty_cache()
        return model, optimizer, scaler, restore_step

    def _get_loader(
        self,
        model: nn.Module,
        config: Coqpit,
        ap: AudioProcessor,
        is_eval: bool,
        data_items: List,
        verbose: bool,
        num_gpus: int,
    ) -> DataLoader:
        if num_gpus > 1:
            if hasattr(model.module, "get_data_loader"):
                loader = model.module.get_data_loader(
                    config, ap, is_eval, data_items, verbose, num_gpus, self.args.rank
                )
        else:
            if hasattr(model, "get_data_loader"):
                loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus)
        return loader

    def get_train_dataloader(self, ap: AudioProcessor, data_items: List, verbose: bool) -> DataLoader:
        """Initialize and return a training data loader.

        Args:
            ap (AudioProcessor): Audio processor.
            data_items (List): Data samples used for training.
            verbose (bool): enable/disable printing loader stats at initialization.

        Returns:
            DataLoader: Initialized training data loader.
        """
        return self._get_loader(self.model, self.config, ap, False, data_items, verbose, self.num_gpus)

    def get_eval_dataloader(self, ap: AudioProcessor, data_items: List, verbose: bool) -> DataLoader:
        return self._get_loader(self.model, self.config, ap, True, data_items, verbose, self.num_gpus)

    def format_batch(self, batch: List) -> Dict:
        """Format the dataloader output and return a batch.

        Args:
            batch (List): Batch returned by the dataloader.

        Returns:
            Dict: Formatted batch.
        """
        if self.num_gpus > 1:
            batch = self.model.module.format_batch(batch)
        else:
            batch = self.model.format_batch(batch)
        if self.use_cuda:
            for k, v in batch.items():
                batch[k] = to_cuda(v)
        return batch

    @staticmethod
    def master_params(optimizer: torch.optim.Optimizer):
        """Generator over parameters owned by the optimizer.

        Used to select parameters used by the optimizer for gradient clipping.

        Args:
            optimizer: Target optimizer.
        """
        for group in optimizer.param_groups:
            for p in group["params"]:
                yield p

    @staticmethod
    def _model_train_step(
        batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: int = None
    ) -> Tuple[Dict, Dict]:
        """
        Perform a trainig forward step. Compute model outputs and losses.

        Args:
            batch (Dict): [description]
            model (nn.Module): [description]
            criterion (nn.Module): [description]
            optimizer_idx (int, optional): [description]. Defaults to None.

        Returns:
            Tuple[Dict, Dict]: [description]
        """
        input_args = [batch, criterion]
        if optimizer_idx is not None:
            input_args.append(optimizer_idx)
        # unwrap model in DDP training
        if hasattr(model, "module"):
            return model.module.train_step(*input_args)
        return model.train_step(*input_args)

    def _optimize(
        self,
        batch: Dict,
        model: nn.Module,
        optimizer: Union[torch.optim.Optimizer, List],
        scaler: "AMPScaler",
        criterion: nn.Module,
        scheduler: Union[torch.optim.lr_scheduler._LRScheduler, List],  # pylint: disable=protected-access
        config: Coqpit,
        optimizer_idx: int = None,
    ) -> Tuple[Dict, Dict, int]:
        """Perform a forward - backward pass and run the optimizer.

        Args:
            batch (Dict): Input batch. If
            model (nn.Module): Model for training. Defaults to None.
            optimizer (Union[nn.optim.Optimizer, List]): Model's optimizer. If it is a list then, `optimizer_idx` must be defined to indicate the optimizer in use.
            scaler (AMPScaler): AMP scaler.
            criterion (nn.Module): Model's criterion.
            scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler used by the optimizer.
            config (Coqpit): Model config.
            optimizer_idx (int, optional): Target optimizer being used. Defaults to None.

        Raises:
            RuntimeError: When the loss is NaN.

        Returns:
            Tuple[Dict, Dict, int, torch.Tensor]: model outputs, losses, step time and gradient norm.
        """

        step_start_time = time.time()
        # zero-out optimizer
        optimizer.zero_grad()

        # forward pass and loss computation
        with torch.cuda.amp.autocast(enabled=config.mixed_precision):
            if optimizer_idx is not None:
                outputs, loss_dict = self._model_train_step(batch, model, criterion, optimizer_idx=optimizer_idx)
            else:
                outputs, loss_dict = self._model_train_step(batch, model, criterion)

        # skip the rest
        if outputs is None:
            step_time = time.time() - step_start_time
            return None, {}, step_time

        # # check nan loss
        # if torch.isnan(loss_dict["loss"]).any():
        #     raise RuntimeError(f" > NaN loss detected  - {loss_dict}")

        # set gradient clipping threshold
        if "grad_clip" in config and config.grad_clip is not None:
            if optimizer_idx is not None:
                grad_clip = config.grad_clip[optimizer_idx]
            else:
                grad_clip = config.grad_clip
        else:
            grad_clip = 0.0  # meaning no gradient clipping

        if grad_clip <= 0:
            grad_norm = 0

        # optimizer step
        update_lr_scheduler = True
        if self.use_amp_scaler:
            if self.use_apex:
                # TODO: verify AMP use for GAN training in TTS
                # https://nvidia.github.io/apex/advanced.html?highlight=accumulate#backward-passes-with-multiple-optimizers
                with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss:
                    scaled_loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), grad_clip, error_if_nonfinite=False
                )
            else:
                # model optimizer step in mixed precision mode
                scaler.scale(loss_dict["loss"]).backward()
                if grad_clip > 0:
                    scaler.unscale_(optimizer)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        self.master_params(optimizer), grad_clip, error_if_nonfinite=False
                    )
                    # pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
                    if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                        grad_norm = 0
                scale_prev = scaler.get_scale()
                scaler.step(optimizer)
                scaler.update()
                update_lr_scheduler = scale_prev <= scaler.get_scale()
        else:
            # main model optimizer step
            loss_dict["loss"].backward()
            if grad_clip > 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False)
            optimizer.step()

        step_time = time.time() - step_start_time

        # setup lr
        if scheduler is not None and update_lr_scheduler and not self.config.scheduler_after_epoch:
            scheduler.step()

        # detach losses
        loss_dict = self._detach_loss_dict(loss_dict)
        if optimizer_idx is not None:
            loss_dict[f"loss_{optimizer_idx}"] = loss_dict.pop("loss")
            loss_dict[f"grad_norm_{optimizer_idx}"] = grad_norm
        else:
            loss_dict["grad_norm"] = grad_norm
        return outputs, loss_dict, step_time

    @staticmethod
    def _detach_loss_dict(loss_dict: Dict) -> Dict:
        """Detach loss values from autograp.

        Args:
            loss_dict (Dict): losses.

        Returns:
            Dict: losses detached from autograph.
        """
        loss_dict_detached = {}
        for key, value in loss_dict.items():
            if isinstance(value, (int, float)):
                loss_dict_detached[key] = value
            else:
                loss_dict_detached[key] = value.item()
        return loss_dict_detached

    def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
        """Perform a training step on a batch of inputs and log the process.

        Args:
            batch (Dict): Input batch.
            batch_n_steps (int): Number of steps needed to complete an epoch. Needed for logging.
            step (int): Current step number in this epoch.
            loader_start_time (float): The time when the data loading is started. Needed for logging.

        Returns:
            Tuple[Dict, Dict]: Model outputs and losses.
        """
        self.callbacks.on_train_step_start()
        # format data
        batch = self.format_batch(batch)
        loader_time = time.time() - loader_start_time

        # conteainers to hold model outputs and losses for each optimizer.
        outputs_per_optimizer = None
        loss_dict = {}
        if not isinstance(self.optimizer, list):
            # training with a single optimizer
            outputs, loss_dict_new, step_time = self._optimize(
                batch, self.model, self.optimizer, self.scaler, self.criterion, self.scheduler, self.config
            )
            loss_dict.update(loss_dict_new)
        else:
            # training with multiple optimizers (e.g. GAN)
            outputs_per_optimizer = [None] * len(self.optimizer)
            total_step_time = 0
            for idx, optimizer in enumerate(self.optimizer):
                criterion = self.criterion
                # scaler = self.scaler[idx] if self.use_amp_scaler else None
                scaler = self.scaler
                scheduler = self.scheduler[idx]
                outputs, loss_dict_new, step_time = self._optimize(
                    batch, self.model, optimizer, scaler, criterion, scheduler, self.config, idx
                )
                # skip the rest if the model returns None
                total_step_time += step_time
                outputs_per_optimizer[idx] = outputs
                # merge loss_dicts from each optimizer
                # rename duplicates with the optimizer idx
                # if None, model skipped this optimizer
                if loss_dict_new is not None:
                    for k, v in loss_dict_new.items():
                        if k in loss_dict:
                            loss_dict[f"{k}-{idx}"] = v
                        else:
                            loss_dict[k] = v
                step_time = total_step_time
            outputs = outputs_per_optimizer

        # update avg runtime stats
        keep_avg_update = {}
        keep_avg_update["avg_loader_time"] = loader_time
        keep_avg_update["avg_step_time"] = step_time
        self.keep_avg_train.update_values(keep_avg_update)

        # update avg loss stats
        update_eval_values = {}
        for key, value in loss_dict.items():
            update_eval_values["avg_" + key] = value
        self.keep_avg_train.update_values(update_eval_values)

        # print training progress
        if self.total_steps_done % self.config.print_step == 0:
            # log learning rates
            lrs = {}
            if isinstance(self.optimizer, list):
                for idx, optimizer in enumerate(self.optimizer):
                    current_lr = self.optimizer[idx].param_groups[0]["lr"]
                    lrs.update({f"current_lr_{idx}": current_lr})
            else:
                current_lr = self.optimizer.param_groups[0]["lr"]
                lrs = {"current_lr": current_lr}

            # log run-time stats
            loss_dict.update(lrs)
            loss_dict.update(
                {
                    "step_time": round(step_time, 4),
                    "loader_time": round(loader_time, 4),
                }
            )
            self.c_logger.print_train_step(
                batch_n_steps, step, self.total_steps_done, loss_dict, self.keep_avg_train.avg_values
            )

        if self.args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load and don't log every step
            if self.total_steps_done % self.config.plot_step == 0:
                self.dashboard_logger.train_step_stats(self.total_steps_done, loss_dict)
            if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0:
                if self.config.checkpoint:
                    # checkpoint the model
                    target_avg_loss = self._pick_target_avg_loss(self.keep_avg_train)
                    save_checkpoint(
                        self.config,
                        self.model,
                        self.optimizer,
                        self.scaler if self.use_amp_scaler else None,
                        self.total_steps_done,
                        self.epochs_done,
                        self.output_path,
                        model_loss=target_avg_loss,
                    )

                    if self.total_steps_done % self.config.log_model_step == 0:
                        # log checkpoint as artifact
                        aliases = [f"epoch-{self.epochs_done}", f"step-{self.total_steps_done}"]
                        self.dashboard_logger.log_artifact(self.output_path, "checkpoint", "model", aliases)

                # training visualizations
                figures, audios = None, None
                if hasattr(self.model, "module") and hasattr(self.model.module, "train_log"):
                    figures, audios = self.model.module.train_log(self.ap, batch, outputs)
                elif hasattr(self.model, "train_log"):
                    figures, audios = self.model.train_log(self.ap, batch, outputs)
                if figures is not None:
                    self.dashboard_logger.train_figures(self.total_steps_done, figures)
                if audios is not None:
                    self.dashboard_logger.train_audios(self.total_steps_done, audios, self.ap.sample_rate)

            self.dashboard_logger.flush()

        self.total_steps_done += 1
        self.callbacks.on_train_step_end()
        return outputs, loss_dict

    def train_epoch(self) -> None:
        """Main entry point for the training loop. Run training on the all training samples."""
        self.train_loader = self.get_train_dataloader(
            self.ap,
            self.data_train,
            verbose=True,
        )
        if self.num_gpus > 1:
            self.model.module.train()
        else:
            self.model.train()
        epoch_start_time = time.time()
        if self.use_cuda:
            batch_num_steps = int(len(self.train_loader.dataset) / (self.config.batch_size * self.num_gpus))
        else:
            batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size)
        self.c_logger.print_train_start()
        loader_start_time = time.time()
        for cur_step, batch in enumerate(self.train_loader):
            _, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time)
            loader_start_time = time.time()
        epoch_time = time.time() - epoch_start_time
        # Plot self.epochs_done Stats
        if self.args.rank == 0:
            epoch_stats = {"epoch_time": epoch_time}
            epoch_stats.update(self.keep_avg_train.avg_values)
            self.dashboard_logger.train_epoch_stats(self.total_steps_done, epoch_stats)
            if self.config.model_param_stats:
                self.logger.model_weights(self.model, self.total_steps_done)
        # scheduler step after the epoch
        if self.scheduler is not None and self.config.scheduler_after_epoch:
            if isinstance(self.scheduler, list):
                for scheduler in self.scheduler:
                    if scheduler is not None:
                        scheduler.step()
            else:
                self.scheduler.step()

    @staticmethod
    def _model_eval_step(
        batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: int = None
    ) -> Tuple[Dict, Dict]:
        """
        Perform a evaluation forward pass. Compute model outputs and losses with no gradients.

        Args:
            batch (Dict): IBatch of inputs.
            model (nn.Module): Model to call evaluation.
            criterion (nn.Module): Model criterion.
            optimizer_idx (int, optional): Optimizer ID to define the closure in multi-optimizer training. Defaults to None.

        Returns:
            Tuple[Dict, Dict]: model outputs and losses.
        """
        input_args = [batch, criterion]
        if optimizer_idx is not None:
            input_args.append(optimizer_idx)
        if hasattr(model, "module"):
            return model.module.eval_step(*input_args)
        return model.eval_step(*input_args)

    def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
        """Perform a evaluation step on a batch of inputs and log the process.

        Args:
            batch (Dict): Input batch.
            step (int): Current step number in this epoch.

        Returns:
            Tuple[Dict, Dict]: Model outputs and losses.
        """
        with torch.no_grad():
            outputs = []
            loss_dict = {}
            if not isinstance(self.optimizer, list):
                outputs, loss_dict = self._model_eval_step(batch, self.model, self.criterion)
            else:
                outputs = [None] * len(self.optimizer)
                for idx, _ in enumerate(self.optimizer):
                    criterion = self.criterion
                    outputs_, loss_dict_new = self._model_eval_step(batch, self.model, criterion, idx)
                    outputs[idx] = outputs_

                    if loss_dict_new is not None:
                        loss_dict_new[f"loss_{idx}"] = loss_dict_new.pop("loss")
                        loss_dict.update(loss_dict_new)

                    loss_dict = self._detach_loss_dict(loss_dict)

            # update avg stats
            update_eval_values = {}
            for key, value in loss_dict.items():
                update_eval_values["avg_" + key] = value
            self.keep_avg_eval.update_values(update_eval_values)

            if self.config.print_eval:
                self.c_logger.print_eval_step(step, loss_dict, self.keep_avg_eval.avg_values)
        return outputs, loss_dict

    def eval_epoch(self) -> None:
        """Main entry point for the evaluation loop. Run evaluation on the all validation samples."""
        self.eval_loader = (
            self.get_eval_dataloader(
                self.ap,
                self.data_eval,
                verbose=True,
            )
            if self.config.run_eval
            else None
        )

        self.model.eval()
        self.c_logger.print_eval_start()
        loader_start_time = time.time()
        batch = None
        for cur_step, batch in enumerate(self.eval_loader):
            # format data
            batch = self.format_batch(batch)
            loader_time = time.time() - loader_start_time
            self.keep_avg_eval.update_values({"avg_loader_time": loader_time})
            outputs, _ = self.eval_step(batch, cur_step)
            loader_start_time = time.time()
        # plot epoch stats, artifacts and figures
        if self.args.rank == 0:
            figures, audios = None, None
            if hasattr(self.model, "module") and hasattr(self.model.module, "eval_log"):
                figures, audios = self.model.module.eval_log(self.ap, batch, outputs)
            elif hasattr(self.model, "eval_log"):
                figures, audios = self.model.eval_log(self.ap, batch, outputs)
            if figures is not None:
                self.dashboard_logger.eval_figures(self.total_steps_done, figures)
            if audios is not None:
                self.dashboard_logger.eval_audios(self.total_steps_done, audios, self.ap.sample_rate)
            self.dashboard_logger.eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values)

    def test_run(self) -> None:
        """Run test and log the results. Test run must be defined by the model.
        Model must return figures and audios to be logged by the Tensorboard."""
        if hasattr(self.model, "test_run") or (self.num_gpus > 1 and hasattr(self.model.module, "test_run")):
            if self.eval_loader is None:
                self.eval_loader = self.get_eval_dataloader(
                    self.ap,
                    self.data_eval,
                    verbose=True,
                )

            if hasattr(self.eval_loader.dataset, "load_test_samples"):
                samples = self.eval_loader.dataset.load_test_samples(1)
                if self.num_gpus > 1:
                    figures, audios = self.model.module.test_run(self.ap, samples, None)
                else:
                    figures, audios = self.model.test_run(self.ap, samples, None)
            else:
                if self.num_gpus > 1:
                    figures, audios = self.model.module.test_run(self.ap)
                else:
                    figures, audios = self.model.test_run(self.ap)
            self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
            self.dashboard_logger.test_figures(self.total_steps_done, figures)

    def _restore_best_loss(self):
        """Restore the best loss from the args.best_path if provided else
        from the model (`args.restore_path` or `args.continue_path`) used for resuming the training"""
        if self.restore_step != 0 or self.args.best_path:
            print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
            ch = load_fsspec(self.args.restore_path, map_location="cpu")
            if "model_loss" in ch:
                self.best_loss = ch["model_loss"]
            print(f" > Starting with loaded last best loss {self.best_loss}.")

    def _fit(self) -> None:
        """🏃 train -> evaluate -> test for the number of epochs."""
        self._restore_best_loss()

        self.total_steps_done = self.restore_step

        for epoch in range(0, self.config.epochs):
            if self.num_gpus > 1:
                # let all processes sync up before starting with a new epoch of training
                dist.barrier()
            self.callbacks.on_epoch_start()
            self.keep_avg_train = KeepAverage()
            self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
            self.epochs_done = epoch
            self.c_logger.print_epoch_start(epoch, self.config.epochs, self.output_path)
            self.train_epoch()
            if self.config.run_eval:
                self.eval_epoch()
            if epoch >= self.config.test_delay_epochs and self.args.rank <= 0:
                self.test_run()
            self.c_logger.print_epoch_end(
                epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
            )
            if self.args.rank in [None, 0]:
                self.save_best_model()
            self.callbacks.on_epoch_end()

    def fit(self) -> None:
        """Where the ✨️magic✨️ happens..."""
        try:
            self._fit()
            if self.args.rank == 0:
                self.dashboard_logger.finish()
        except KeyboardInterrupt:
            self.callbacks.on_keyboard_interrupt()
            # if the output folder is empty remove the run.
            remove_experiment_folder(self.output_path)
            # clear the DDP processes
            if self.num_gpus > 1:
                dist.destroy_process_group()
            # finish the wandb run and sync data
            if self.args.rank == 0:
                self.dashboard_logger.finish()
            # stop without error signal
            try:
                sys.exit(0)
            except SystemExit:
                os._exit(0)  # pylint: disable=protected-access
        except BaseException:  # pylint: disable=broad-except
            remove_experiment_folder(self.output_path)
            traceback.print_exc()
            sys.exit(1)

    def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
        """Pick the target loss to compare models"""
        target_avg_loss = None

        # return if target loss defined in the model config
        if "target_loss" in self.config and self.config.target_loss:
            return keep_avg_target[f"avg_{self.config.target_loss}"]

        # take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers
        if isinstance(self.optimizer, list):
            target_avg_loss = 0
            for idx in range(len(self.optimizer)):
                target_avg_loss += keep_avg_target[f"avg_loss_{idx}"]
            target_avg_loss /= len(self.optimizer)
        else:
            target_avg_loss = keep_avg_target["avg_loss"]
        return target_avg_loss

    def save_best_model(self) -> None:
        """Save the best model. It only saves if the current target loss is smaller then the previous."""

        # set the target loss to choose the best model
        target_loss_dict = self._pick_target_avg_loss(self.keep_avg_eval if self.keep_avg_eval else self.keep_avg_train)

        # save the model and update the best_loss
        self.best_loss = save_best_model(
            target_loss_dict,
            self.best_loss,
            self.config,
            self.model,
            self.optimizer,
            self.scaler if self.use_amp_scaler else None,
            self.total_steps_done,
            self.epochs_done,
            self.output_path,
            keep_all_best=self.config.keep_all_best,
            keep_after=self.config.keep_after,
        )

    def _setup_logger_config(self, log_file: str) -> None:
        """Write log strings to a file and print logs to the terminal.
        TODO: Causes formatting issues in pdb debugging."""

        class Logger(object):
            def __init__(self, print_to_terminal=True):
                self.print_to_terminal = print_to_terminal
                self.terminal = sys.stdout
                self.log_file = log_file

            def write(self, message):
                if self.print_to_terminal:
                    self.terminal.write(message)
                with open(self.log_file, "a", encoding="utf-8") as f:
                    f.write(message)

            def flush(self):
                # this flush method is needed for python 3 compatibility.
                # this handles the flush command by doing nothing.
                # you might want to specify some extra behavior here.
                pass

        # don't let processes rank > 0 write to the terminal
        sys.stdout = Logger(self.args.rank == 0)

    @staticmethod
    def _is_apex_available() -> bool:
        """Check if Nvidia's APEX is available."""
        return importlib.util.find_spec("apex") is not None

    @staticmethod
    def get_optimizer(model: nn.Module, config: Coqpit) -> Union[torch.optim.Optimizer, List]:
        """Receive the optimizer from the model if model implements `get_optimizer()` else
        check the optimizer parameters in the config and try initiating the optimizer.

        Args:
            model (nn.Module): Training model.
            config (Coqpit): Training configuration.

        Returns:
            Union[torch.optim.Optimizer, List]: A optimizer or a list of optimizers. GAN models define a list.
        """
        if hasattr(model, "get_optimizer"):
            optimizer = model.get_optimizer()
        if optimizer is None:
            optimizer_name = config.optimizer
            optimizer_params = config.optimizer_params
            return get_optimizer(optimizer_name, optimizer_params, config.lr, model)
        return optimizer

    @staticmethod
    def get_lr(model: nn.Module, config: Coqpit) -> Union[float, List[float]]:
        """Set the initial learning rate by the model if model implements `get_lr()` else try setting the learning rate
        fromthe config.

        Args:
            model (nn.Module): Training model.
            config (Coqpit): Training configuration.

        Returns:
            Union[float, List[float]]: A single learning rate or a list of learning rates, one for each optimzier.
        """
        lr = None
        if hasattr(model, "get_lr"):
            lr = model.get_lr()
        if lr is None:
            lr = config.lr
        return lr

    @staticmethod
    def get_scheduler(
        model: nn.Module, config: Coqpit, optimizer: Union[torch.optim.Optimizer, List]
    ) -> Union[torch.optim.lr_scheduler._LRScheduler, List]:  # pylint: disable=protected-access
        """Receive the scheduler from the model if model implements `get_scheduler()` else
        check the config and try initiating the scheduler.

        Args:
            model (nn.Module): Training model.
            config (Coqpit): Training configuration.

        Returns:
            Union[torch.optim.Optimizer, List]: A scheduler or a list of schedulers, one for each optimizer.
        """
        scheduler = None
        if hasattr(model, "get_scheduler"):
            scheduler = model.get_scheduler(optimizer)
        if scheduler is None:
            lr_scheduler = config.lr_scheduler
            lr_scheduler_params = config.lr_scheduler_params
            return get_scheduler(lr_scheduler, lr_scheduler_params, optimizer)
        return scheduler

    @staticmethod
    def get_criterion(model: nn.Module) -> nn.Module:
        """Receive the criterion from the model. Model must implement `get_criterion()`.

        Args:
            model (nn.Module): Training model.

        Returns:
            nn.Module: Criterion layer.
        """
        criterion = None
        criterion = model.get_criterion()
        return criterion


def getarguments():
    train_config = TrainingArgs()
    parser = train_config.init_argparse(arg_prefix="")
    return parser


def get_last_checkpoint(path: str) -> Tuple[str, str]:
    """Get latest checkpoint or/and best model in path.

    It is based on globbing for `*.pth.tar` and the RegEx
    `(checkpoint|best_model)_([0-9]+)`.

    Args:
        path: Path to files to be compared.

    Raises:
        ValueError: If no checkpoint or best_model files are found.

    Returns:
        Path to the last checkpoint
        Path to best checkpoint
    """
    fs = fsspec.get_mapper(path).fs
    file_names = fs.glob(os.path.join(path, "*.pth.tar"))
    scheme = urlparse(path).scheme
    if scheme:  # scheme is not preserved in fs.glob, add it back
        file_names = [scheme + "://" + file_name for file_name in file_names]
    last_models = {}
    last_model_nums = {}
    for key in ["checkpoint", "best_model"]:
        last_model_num = None
        last_model = None
        # pass all the checkpoint files and find
        # the one with the largest model number suffix.
        for file_name in file_names:
            match = re.search(f"{key}_([0-9]+)", file_name)
            if match is not None:
                model_num = int(match.groups()[0])
                if last_model_num is None or model_num > last_model_num:
                    last_model_num = model_num
                    last_model = file_name

        # if there is no checkpoint found above
        # find the checkpoint with the latest
        # modification date.
        key_file_names = [fn for fn in file_names if key in fn]
        if last_model is None and len(key_file_names) > 0:
            last_model = max(key_file_names, key=os.path.getctime)
            last_model_num = load_fsspec(last_model)["step"]

        if last_model is not None:
            last_models[key] = last_model
            last_model_nums[key] = last_model_num

    # check what models were found
    if not last_models:
        raise ValueError(f"No models found in continue path {path}!")
    if "checkpoint" not in last_models:  # no checkpoint just best model
        last_models["checkpoint"] = last_models["best_model"]
    elif "best_model" not in last_models:  # no best model
        # this shouldn't happen, but let's handle it just in case
        last_models["best_model"] = last_models["checkpoint"]
    # finally check if last best model is more recent than checkpoint
    elif last_model_nums["best_model"] > last_model_nums["checkpoint"]:
        last_models["checkpoint"] = last_models["best_model"]

    return last_models["checkpoint"], last_models["best_model"]


def process_args(args, config=None):
    """Process parsed comand line arguments and initialize the config if not provided.

    Args:
        args (argparse.Namespace or dict like): Parsed input arguments.
        config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None.

    Returns:
        c (TTS.utils.io.AttrDict): Config paramaters.
        out_path (str): Path to save models and logging.
        audio_path (str): Path to save generated test audios.
        c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
            logging to the console.

        dashboard_logger (WandbLogger or TensorboardLogger): Class that does the dashboard Logging

    TODO:
        - Interactive config definition.
    """
    if isinstance(args, tuple):
        args, coqpit_overrides = args
    if args.continue_path:
        # continue a previous training from its output folder
        experiment_path = args.continue_path
        args.config_path = os.path.join(args.continue_path, "config.json")
        args.restore_path, best_model = get_last_checkpoint(args.continue_path)
        if not args.best_path:
            args.best_path = best_model
    # init config if not already defined
    if config is None:
        if args.config_path:
            # init from a file
            config = load_config(args.config_path)
        else:
            # init from console args
            from TTS.config.shared_configs import BaseTrainingConfig  # pylint: disable=import-outside-toplevel

            config_base = BaseTrainingConfig()
            config_base.parse_known_args(coqpit_overrides)
            config = register_config(config_base.model)()
    # override values from command-line args
    config.parse_known_args(coqpit_overrides, relaxed_parser=True)
    experiment_path = args.continue_path
    if not experiment_path:
        experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
    audio_path = os.path.join(experiment_path, "test_audios")
    config.output_log_path = experiment_path
    # setup rank 0 process in distributed training
    dashboard_logger = None
    if args.rank == 0:
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        # if model characters are not set in the config file
        # save the default set to the config file for future
        # compatibility.
        if config.has("characters") and config.characters is None:
            used_characters = parse_symbols()
            new_fields["characters"] = used_characters
        copy_model_files(config, experiment_path, new_fields)
        dashboard_logger = init_dashboard_logger(config)
    c_logger = ConsoleLogger()
    return config, experiment_path, audio_path, c_logger, dashboard_logger


def init_arguments():
    train_config = TrainingArgs()
    parser = train_config.init_argparse(arg_prefix="")
    return parser


def init_training(argv: Union[List, Coqpit], config: Coqpit = None):
    """Initialization of a training run."""
    if isinstance(argv, Coqpit):
        parser = argv.init_argparse(arg_prefix="")
    else:
        parser = init_arguments()
    args = parser.parse_known_args()
    config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = process_args(args, config)
    return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger