import torch
from torch import nn

from TTS.tts.layers.generic.res_conv_bn import Conv1dBN, Conv1dBNBlock, ResidualConv1dBNBlock
from TTS.tts.layers.generic.transformer import FFTransformerBlock
from TTS.tts.layers.generic.wavenet import WNBlocks
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer


class WaveNetDecoder(nn.Module):
    """WaveNet based decoder with a prenet and a postnet.

    prenet: conv1d_1x1
    postnet: 3 x [conv1d_1x1 -> relu] -> conv1d_1x1

    TODO: Integrate speaker conditioning vector.

    Note:
        default wavenet parameters;
            params = {
                "num_blocks": 12,
                "hidden_channels":192,
                "kernel_size": 5,
                "dilation_rate": 1,
                "num_layers": 4,
                "dropout_p": 0.05
            }

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of hidden channels for prenet and postnet.
        params (dict): dictionary for residual convolutional blocks.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, c_in_channels, params):
        super().__init__()
        # prenet
        self.prenet = torch.nn.Conv1d(in_channels, params["hidden_channels"], 1)
        # wavenet layers
        self.wn = WNBlocks(params["hidden_channels"], c_in_channels=c_in_channels, **params)
        # postnet
        self.postnet = [
            torch.nn.Conv1d(params["hidden_channels"], hidden_channels, 1),
            torch.nn.ReLU(),
            torch.nn.Conv1d(hidden_channels, hidden_channels, 1),
            torch.nn.ReLU(),
            torch.nn.Conv1d(hidden_channels, hidden_channels, 1),
            torch.nn.ReLU(),
            torch.nn.Conv1d(hidden_channels, out_channels, 1),
        ]
        self.postnet = nn.Sequential(*self.postnet)

    def forward(self, x, x_mask=None, g=None):
        x = self.prenet(x) * x_mask
        x = self.wn(x, x_mask, g)
        o = self.postnet(x) * x_mask
        return o


class RelativePositionTransformerDecoder(nn.Module):
    """Decoder with Relative Positional Transformer.

    Note:
        Default params
            params={
                'hidden_channels_ffn': 128,
                'num_heads': 2,
                "kernel_size": 3,
                "dropout_p": 0.1,
                "num_layers": 8,
                "rel_attn_window_size": 4,
                "input_length": None
            }

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of hidden channels including Transformer layers.
        params (dict): dictionary for residual convolutional blocks.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, params):

        super().__init__()
        self.prenet = Conv1dBN(in_channels, hidden_channels, 1, 1)
        self.rel_pos_transformer = RelativePositionTransformer(in_channels, out_channels, hidden_channels, **params)

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        o = self.prenet(x) * x_mask
        o = self.rel_pos_transformer(o, x_mask)
        return o


class FFTransformerDecoder(nn.Module):
    """Decoder with FeedForwardTransformer.

    Default params
            params={
                'hidden_channels_ffn': 1024,
                'num_heads': 2,
                "dropout_p": 0.1,
                "num_layers": 6,
            }

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of hidden channels including Transformer layers.
        params (dict): dictionary for residual convolutional blocks.
    """

    def __init__(self, in_channels, out_channels, params):

        super().__init__()
        self.transformer_block = FFTransformerBlock(in_channels, **params)
        self.postnet = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        # TODO: handle multi-speaker
        x_mask = 1 if x_mask is None else x_mask
        o = self.transformer_block(x) * x_mask
        o = self.postnet(o) * x_mask
        return o


class ResidualConv1dBNDecoder(nn.Module):
    """Residual Convolutional Decoder as in the original Speedy Speech paper

    TODO: Integrate speaker conditioning vector.

    Note:
        Default params
                params = {
                    "kernel_size": 4,
                    "dilations": 4 * [1, 2, 4, 8] + [1],
                    "num_conv_blocks": 2,
                    "num_res_blocks": 17
                }

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of hidden channels including ResidualConv1dBNBlock layers.
        params (dict): dictionary for residual convolutional blocks.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, params):
        super().__init__()
        self.res_conv_block = ResidualConv1dBNBlock(in_channels, hidden_channels, hidden_channels, **params)
        self.post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1)
        self.postnet = nn.Sequential(
            Conv1dBNBlock(
                hidden_channels, hidden_channels, hidden_channels, params["kernel_size"], 1, num_conv_blocks=2
            ),
            nn.Conv1d(hidden_channels, out_channels, 1),
        )

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        o = self.res_conv_block(x, x_mask)
        o = self.post_conv(o) + x
        return self.postnet(o) * x_mask


class Decoder(nn.Module):
    """Decodes the expanded phoneme encoding into spectrograms
    Args:
        out_channels (int): number of output channels.
        in_hidden_channels (int): input and hidden channels. Model keeps the input channels for the intermediate layers.
        decoder_type (str): decoder layer types. 'transformers' or 'residual_conv_bn'. Default 'residual_conv_bn'.
        decoder_params (dict): model parameters for specified decoder type.
        c_in_channels (int): number of channels for conditional input.

    Shapes:
        - input: (B, C, T)
    """

    # pylint: disable=dangerous-default-value
    def __init__(
        self,
        out_channels,
        in_hidden_channels,
        decoder_type="residual_conv_bn",
        decoder_params={
            "kernel_size": 4,
            "dilations": 4 * [1, 2, 4, 8] + [1],
            "num_conv_blocks": 2,
            "num_res_blocks": 17,
        },
        c_in_channels=0,
    ):
        super().__init__()

        if decoder_type.lower() == "relative_position_transformer":
            self.decoder = RelativePositionTransformerDecoder(
                in_channels=in_hidden_channels,
                out_channels=out_channels,
                hidden_channels=in_hidden_channels,
                params=decoder_params,
            )
        elif decoder_type.lower() == "residual_conv_bn":
            self.decoder = ResidualConv1dBNDecoder(
                in_channels=in_hidden_channels,
                out_channels=out_channels,
                hidden_channels=in_hidden_channels,
                params=decoder_params,
            )
        elif decoder_type.lower() == "wavenet":
            self.decoder = WaveNetDecoder(
                in_channels=in_hidden_channels,
                out_channels=out_channels,
                hidden_channels=in_hidden_channels,
                c_in_channels=c_in_channels,
                params=decoder_params,
            )
        elif decoder_type.lower() == "fftransformer":
            self.decoder = FFTransformerDecoder(in_hidden_channels, out_channels, decoder_params)
        else:
            raise ValueError(f"[!] Unknown decoder type - {decoder_type}")

    def forward(self, x, x_mask, g=None):  # pylint: disable=unused-argument
        """
        Args:
            x: [B, C, T]
            x_mask: [B, 1, T]
            g: [B, C_g, 1]
        """
        # TODO: implement multi-speaker
        o = self.decoder(x, x_mask, g)
        return o