import math
import torch
from torch import nn

from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.layers.generic.res_conv_bn import  ResidualConv1dBNBlock



class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for non-recurrent neural networks.

    Implementation based on "Attention Is All You Need". The sin/cos table is
    computed once for ``max_len`` positions and registered as the buffer
    ``pe`` with shape ``[1, channels, max_len]``.

    Args:
        channels (int): embedding size; must be even (sin and cos interleave).
        dropout (float): dropout probability applied to the output. When it is
            not positive, no dropout module is created.
        max_len (int): number of precomputed positions.

    Raises:
        ValueError: if ``channels`` is odd.
    """
    def __init__(self, channels, dropout=0.0, max_len=5000):
        super().__init__()
        if channels % 2 != 0:
            raise ValueError(
                "Cannot use sin/cos positional encoding with "
                "odd channels (got channels={:d})".format(channels))
        # Column vector of positions [max_len, 1] times row of frequencies
        # [channels // 2] -> angle grid [max_len, channels // 2].
        positions = torch.arange(0, max_len).unsqueeze(1).float()
        freqs = torch.exp(
            torch.arange(0, channels, 2, dtype=torch.float)
            * -(math.log(10000.0) / channels))
        angles = positions * freqs
        table = torch.zeros(max_len, channels)
        table[:, 0::2] = torch.sin(angles)  # even channels
        table[:, 1::2] = torch.cos(angles)  # odd channels
        # Stored channel-first to match [B, C, T] inputs.
        self.register_buffer('pe', table.unsqueeze(0).transpose(1, 2))
        if dropout > 0:
            self.dropout = nn.Dropout(p=dropout)
        self.channels = channels

    def forward(self, x, mask=None, first_idx=None, last_idx=None):
        """Scale ``x`` by ``sqrt(channels)`` and add the positional table.

        Shapes:
            x: [B, C, T]
            mask: [B, 1, T]
            first_idx: int
            last_idx: int

        If ``first_idx`` is given, table positions ``first_idx:last_idx`` are
        added and ``mask`` is ignored. Otherwise the first ``T`` positions are
        added, multiplied by ``mask`` when one is supplied.

        Raises:
            RuntimeError: if ``first_idx`` is None and ``T`` exceeds the
                precomputed ``max_len``.
        """
        x = x * math.sqrt(self.channels)
        if first_idx is not None:
            x = x + self.pe[:, :, first_idx:last_idx]
        else:
            seq_len = x.size(2)
            table_len = self.pe.size(2)
            if table_len < seq_len:
                raise RuntimeError(
                    f"Sequence is {seq_len} but PositionalEncoding is"
                    f" limited to {table_len}. See max_len argument.")
            pos_enc = self.pe[:, :, :seq_len]
            if mask is not None:
                pos_enc = pos_enc * mask
            x = x + pos_enc
        # The dropout module only exists when dropout > 0 was requested.
        drop = getattr(self, 'dropout', None)
        if drop is not None:
            x = drop(x)
        return x


class RelativePositionTransformerEncoder(nn.Module):
    """Speedy Speech encoder: a residual conv prenet followed by a Transformer
    with relative position encoding.

    TODO: Integrate speaker conditioning vector.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of hidden channels.
        params (dict): keyword arguments forwarded to
            ``RelativePositionTransformer``.
    """
    def __init__(self, in_channels, out_channels, hidden_channels, params):
        super().__init__()
        # Fixed prenet configuration: 3 residual blocks of one conv each,
        # kernel size 5, no dilation.
        prenet_kwargs = dict(kernel_size=5,
                             num_res_blocks=3,
                             num_conv_blocks=1,
                             dilations=[1, 1, 1])
        self.prenet = ResidualConv1dBNBlock(in_channels, hidden_channels,
                                            hidden_channels, **prenet_kwargs)
        self.rel_pos_transformer = RelativePositionTransformer(
            hidden_channels, out_channels, hidden_channels, **params)

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        """Run the prenet, mask its output, then apply the transformer.

        ``g`` is accepted for interface compatibility and ignored. When
        ``x_mask`` is None the scalar ``1`` is used as the mask; whether
        ``RelativePositionTransformer`` accepts a scalar mask is not visible
        here — NOTE(review): confirm.
        """
        mask = 1 if x_mask is None else x_mask
        hidden = self.prenet(x) * mask
        return self.rel_pos_transformer(hidden, mask)


class ResidualConv1dBNEncoder(nn.Module):
    """Residual convolutional encoder as in the original Speedy Speech paper.

    Pipeline: 1x1 conv + ReLU prenet -> ``ResidualConv1dBNBlock`` -> skip
    connection adding the raw input -> 1x1 conv / ReLU / BatchNorm / 1x1 conv
    postnet.

    TODO: Integrate speaker conditioning vector.

    Note:
        The skip connection adds the encoder input ``x`` to the hidden
        features, so ``in_channels`` must match ``hidden_channels`` (or
        broadcast against it).

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of hidden channels.
        params (dict): keyword arguments forwarded to
            ``ResidualConv1dBNBlock``.
    """
    def __init__(self, in_channels, out_channels, hidden_channels, params):
        super().__init__()
        self.prenet = nn.Sequential(
            nn.Conv1d(in_channels, hidden_channels, 1),
            nn.ReLU(),
        )
        self.res_conv_block = ResidualConv1dBNBlock(
            hidden_channels, hidden_channels, hidden_channels, **params)
        self.postnet = nn.Sequential(
            nn.Conv1d(hidden_channels, hidden_channels, 1),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_channels),
            nn.Conv1d(hidden_channels, out_channels, 1),
        )

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        """Encode ``x``; ``g`` is accepted for interface compatibility and ignored.

        When ``x_mask`` is None the scalar ``1`` is used, i.e. no masking.
        """
        mask = 1 if x_mask is None else x_mask
        hidden = self.prenet(x) * mask
        hidden = self.res_conv_block(hidden, mask)
        # Skip connection from the raw input (see the channel note above).
        out = self.postnet(hidden + x) * mask
        # Masked a second time, as in the original; a no-op for 0/1 masks.
        return out * mask


class Encoder(nn.Module):
    """Factory class for the Speedy Speech encoder; builds one of several encoder types internally.

    Args:
        in_hidden_channels (int): input and hidden channels. The model keeps the input channel
            count for the intermediate layers.
        out_channels (int): number of output channels.
        encoder_type (str): encoder layer type, 'transformer' or 'residual_conv_bn'
            (case-insensitive). Default 'residual_conv_bn'.
        encoder_params (dict, optional): model parameters for the specified encoder type. When
            None, the defaults listed below for the selected type are used.
        c_in_channels (int): number of channels for conditional input. Stored on the instance;
            not used by this class.

    Raises:
        NotImplementedError: if ``encoder_type`` is not a supported type.

    Note:
        Default encoder_params...

        for 'transformer'
            encoder_params={
                'hidden_channels_ffn': 128,
                'num_heads': 2,
                "kernel_size": 3,
                "dropout_p": 0.1,
                "num_layers": 6,
                "rel_attn_window_size": 4,
                "input_length": None
            },

        for 'residual_conv_bn'
            encoder_params = {
                "kernel_size": 4,
                "dilations": 4 * [1, 2, 4] + [1],
                "num_conv_blocks": 2,
                "num_res_blocks": 13
            }
    """
    def __init__(
            self,
            in_hidden_channels,
            out_channels,
            encoder_type='residual_conv_bn',
            encoder_params=None,
            c_in_channels=0):
        super().__init__()
        self.out_channels = out_channels
        self.in_channels = in_hidden_channels
        self.hidden_channels = in_hidden_channels
        self.encoder_type = encoder_type
        self.c_in_channels = c_in_channels

        kind = encoder_type.lower()
        if encoder_params is None:
            # Defaults are picked per type. A single shared default dict used to
            # hand the residual-conv parameters to the transformer encoder too,
            # which expects a different parameter set (see the class docstring).
            encoder_params = self._default_encoder_params(kind)

        # init encoder
        if kind == "transformer":
            # text encoder
            self.encoder = RelativePositionTransformerEncoder(
                in_hidden_channels, out_channels, in_hidden_channels,
                encoder_params)
        elif kind == 'residual_conv_bn':
            self.encoder = ResidualConv1dBNEncoder(in_hidden_channels,
                                                   out_channels,
                                                   in_hidden_channels,
                                                   encoder_params)
        else:
            raise NotImplementedError(' [!] unknown encoder type.')

    @staticmethod
    def _default_encoder_params(encoder_type):
        """Return a fresh dict of default parameters for a lower-cased ``encoder_type``.

        A new dict (and list) is built on every call, so callers may mutate the result safely.

        Raises:
            NotImplementedError: if ``encoder_type`` is not a supported type.
        """
        if encoder_type == "transformer":
            return {
                "hidden_channels_ffn": 128,
                "num_heads": 2,
                "kernel_size": 3,
                "dropout_p": 0.1,
                "num_layers": 6,
                "rel_attn_window_size": 4,
                "input_length": None,
            }
        if encoder_type == "residual_conv_bn":
            return {
                "kernel_size": 4,
                "dilations": 4 * [1, 2, 4] + [1],
                "num_conv_blocks": 2,
                "num_res_blocks": 13,
            }
        raise NotImplementedError(' [!] unknown encoder type.')

    def forward(self, x, x_mask, g=None):  # pylint: disable=unused-argument
        """Encode ``x`` with the selected encoder and zero out padded frames.

        ``g`` is accepted for interface compatibility and ignored.

        Shapes:
            x: [B, C, T]
            x_mask: [B, 1, T]
            g: [B, C, 1]
        """
        o = self.encoder(x, x_mask)
        return o * x_mask