from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
from coqpit import Coqpit
from torch import nn

from TTS.utils.audio import AudioProcessor

# pylint: skip-file


class BaseModel(nn.Module, ABC):
    """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this.

    Subclasses must implement `forward()`, `inference()`, `train_step()`, `eval_step()` and `load_checkpoint()`.
    The remaining methods are optional hooks that return `None` values unless overridden.

    Notes on input/output tensor shapes:
        Any input or output tensor of the model must be shaped as

        - 3D tensors `batch x time x channels`
        - 2D tensors `batch x channels`
        - 1D tensors `batch x 1`
    """

    @abstractmethod
    def forward(self, text: torch.Tensor, aux_input: Union[Dict, None] = None, **kwargs) -> Dict:
        """Forward pass for the model mainly used in training.

        You can be flexible here and use different number of arguments and argument names since it is mostly used by
        `train_step()` in training without exposing it outside of the class.

        Args:
            text (torch.Tensor): Input text character sequence ids.
            aux_input (Dict, optional): Auxiliary model inputs like embeddings, durations or any other sorts of
                inputs for the model. Defaults to None, which implementations should treat as "no auxiliary
                inputs". A `None` default is used instead of `{}` so that no dict instance is shared between
                calls.

        Returns:
            Dict: model outputs. This must include an item keyed `model_outputs` as the final artifact of the model.
        """
        outputs_dict = {"model_outputs": None}
        ...
        return outputs_dict

    @abstractmethod
    def inference(self, text: torch.Tensor, aux_input: Union[Dict, None] = None) -> Dict:
        """Forward pass for inference.

        After the model is trained this is the only function that connects the model to the outside world.

        This function must only take a `text` input and a dictionary that has all the other model specific inputs.
        We don't use `**kwargs` since it is problematic with the TorchScript API.

        Args:
            text (torch.Tensor): Input text character sequence ids.
            aux_input (Dict, optional): Auxiliary inputs like speaker embeddings, durations etc. Defaults to None,
                which implementations should treat as "no auxiliary inputs". A `None` default is used instead of
                `{}` so that no dict instance is shared between calls.

        Returns:
            Dict: model outputs. This must include an item keyed `model_outputs` as the final artifact of the model.
        """
        outputs_dict = {"model_outputs": None}
        ...
        return outputs_dict

    @abstractmethod
    def train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        """Perform a single training step. Run the model forward pass and compute losses.

        Args:
            batch (Dict): Input tensors.
            criterion (nn.Module): Loss layer designed for the model.

        Returns:
            Tuple[Dict, Dict]: Model outputs and computed losses.
        """
        outputs_dict = {}
        loss_dict = {}  # this returns from the criterion
        ...
        return outputs_dict, loss_dict

    def train_log(self, ap: "AudioProcessor", batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
        """Create visualizations and waveform examples for training.

        For example, here you can plot spectrograms and generate sample waveforms from these spectrograms to
        be projected onto Tensorboard.

        The base implementation creates nothing and returns `(None, None)`.

        Args:
            ap (AudioProcessor): audio processor used at training.
            batch (Dict): Model inputs used at the previous training step.
            outputs (Dict): Model outputs generated at the previous training step.

        Returns:
            Tuple[Dict, np.ndarray]: training plots and output waveform.
        """
        return None, None

    @abstractmethod
    def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        """Perform a single evaluation step. Run the model forward pass and compute losses. In most cases, you can
        call `train_step()` with no changes.

        Args:
            batch (Dict): Input tensors.
            criterion (nn.Module): Loss layer designed for the model.

        Returns:
            Tuple[Dict, Dict]: Model outputs and computed losses.
        """
        outputs_dict = {}
        loss_dict = {}  # this returns from the criterion
        ...
        return outputs_dict, loss_dict

    def eval_log(self, ap: "AudioProcessor", batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
        """Create visualizations and waveform examples for evaluation.

        Takes the same arguments and returns the same kind of values as `train_log()`. The base implementation
        creates nothing and returns `(None, None)`.
        """
        return None, None

    @abstractmethod
    def load_checkpoint(self, config: "Coqpit", checkpoint_path: str, eval: bool = False) -> None:
        """Load a checkpoint and get ready for training or inference.

        Args:
            config (Coqpit): Model configuration.
            checkpoint_path (str): Path to the model checkpoint file.
            eval (bool, optional): If true, init model for inference else for training. Defaults to False.
        """
        ...

    def get_optimizer(self) -> Union[torch.optim.Optimizer, List[torch.optim.Optimizer]]:
        """Set up and return optimizer or optimizers.

        The base implementation is a placeholder that returns `None`.
        """
        return None

    def get_lr(self) -> Union[float, List[float]]:
        """Return learning rate(s).

        The base implementation is a placeholder that returns `None`.

        Returns:
            Union[float, List[float]]: Model's initial learning rates.
        """
        return None

    def get_scheduler(self, optimizer: torch.optim.Optimizer):
        """Optional hook taking the optimizer; presumably meant to return a learning rate scheduler for it.

        The base implementation is a placeholder that ignores `optimizer` and returns `None`.

        Args:
            optimizer (torch.optim.Optimizer): Optimizer passed in by the caller.
        """
        return None

    def get_criterion(self):
        """Optional hook; presumably meant to return the loss layer passed to `train_step()` / `eval_step()`.

        The base implementation is a placeholder that returns `None`.
        """
        return None

    def format_batch(self):
        """Optional hook; presumably meant for model specific batch formatting.

        The base implementation is a placeholder that does nothing and returns `None`.
        """
        return None