import math

import torch
from torch import nn

from TTS.tts.layers.generic.gated_conv import GatedConvBlock
from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock
from TTS.tts.layers.generic.time_depth_sep_conv import TimeDepthSeparableConvBlock
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.glow_tts.glow import ResidualConv1dLayerNormBlock
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.utils.helpers import sequence_mask


class Encoder(nn.Module):
    """Glow-TTS encoder module.

    ::

        embedding -> <prenet> -> encoder_module -> <postnet> --> proj_mean
                                                             |
                                                             |-> proj_var
                                                             |
                                                             |-> concat -> duration_predictor
                                                                    ↑
                                                              speaker_embed

    Args:
        num_chars (int): number of characters.
        out_channels (int): number of output channels.
        hidden_channels (int): encoder's embedding size.
        hidden_channels_ffn (int): transformer's feed-forward channels.
        kernel_size (int): kernel size for conv layers and duration predictor.
        dropout_p (float): dropout rate for any dropout layer.
        mean_only (bool): if True, output only mean values and use constant std.
        use_prenet (bool): if True, use pre-convolutional layers before transformer layers.
        c_in_channels (int): number of channels in conditional input.

    Shapes:
        - input: (B, T, C)

    ::

        suggested encoder params...

        for encoder_type == 'rel_pos_transformer'
            encoder_params={
                'kernel_size':3,
                'dropout_p': 0.1,
                'num_layers': 6,
                'num_heads': 2,
                'hidden_channels_ffn': 768,  # 4 times the hidden_channels
                'input_length': None
            }

        for encoder_type == 'gated_conv'
            encoder_params={
                'kernel_size':5,
                'dropout_p': 0.1,
                'num_layers': 9,
            }

        for encoder_type == 'residual_conv_bn'
            encoder_params={
                "kernel_size": 4,
                "dilations": [1, 2, 4, 1, 2, 4, 1, 2, 4, 1, 2, 4, 1],
                "num_conv_blocks": 2,
                "num_res_blocks": 13
            }

         for encoder_type == 'time_depth_separable'
            encoder_params={
                "kernel_size": 5,
                'num_layers': 9,
            }
    """

    def __init__(
        self,
        num_chars,
        out_channels,
        hidden_channels,
        hidden_channels_dp,
        encoder_type,
        encoder_params,
        dropout_p_dp=0.1,
        mean_only=False,
        use_prenet=True,
        c_in_channels=0,
    ):
        super().__init__()
        # class arguments
        self.num_chars = num_chars
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.hidden_channels_dp = hidden_channels_dp
        self.dropout_p_dp = dropout_p_dp
        self.mean_only = mean_only
        self.use_prenet = use_prenet
        self.c_in_channels = c_in_channels
        self.encoder_type = encoder_type
        # embedding layer
        self.emb = nn.Embedding(num_chars, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
        # init encoder module
        if encoder_type.lower() == "rel_pos_transformer":
            if use_prenet:
                self.prenet = ResidualConv1dLayerNormBlock(
                    hidden_channels, hidden_channels, hidden_channels, kernel_size=5, num_layers=3, dropout_p=0.5
                )
            self.encoder = RelativePositionTransformer(
                hidden_channels, hidden_channels, hidden_channels, **encoder_params
            )
        elif encoder_type.lower() == "gated_conv":
            self.encoder = GatedConvBlock(hidden_channels, **encoder_params)
        elif encoder_type.lower() == "residual_conv_bn":
            if use_prenet:
                self.prenet = nn.Sequential(nn.Conv1d(hidden_channels, hidden_channels, 1), nn.ReLU())
            self.encoder = ResidualConv1dBNBlock(hidden_channels, hidden_channels, hidden_channels, **encoder_params)
            self.postnet = nn.Sequential(
                nn.Conv1d(self.hidden_channels, self.hidden_channels, 1), nn.BatchNorm1d(self.hidden_channels)
            )
        elif encoder_type.lower() == "time_depth_separable":
            if use_prenet:
                self.prenet = ResidualConv1dLayerNormBlock(
                    hidden_channels, hidden_channels, hidden_channels, kernel_size=5, num_layers=3, dropout_p=0.5
                )
            self.encoder = TimeDepthSeparableConvBlock(
                hidden_channels, hidden_channels, hidden_channels, **encoder_params
            )
        else:
            raise ValueError(" [!] Unkown encoder type.")

        # final projection layers
        self.proj_m = nn.Conv1d(hidden_channels, out_channels, 1)
        if not mean_only:
            self.proj_s = nn.Conv1d(hidden_channels, out_channels, 1)
        # duration predictor
        self.duration_predictor = DurationPredictor(
            hidden_channels + c_in_channels, hidden_channels_dp, 3, dropout_p_dp
        )

    def forward(self, x, x_lengths, g=None):
        """
        Shapes:
            - x: :math:`[B, C, T]`
            - x_lengths: :math:`[B]`
            - g (optional): :math:`[B, 1, T]`
        """
        # embedding layer
        # [B ,T, D]
        x = self.emb(x) * math.sqrt(self.hidden_channels)
        # [B, D, T]
        x = torch.transpose(x, 1, -1)
        # compute input sequence mask
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        # prenet
        if hasattr(self, "prenet") and self.use_prenet:
            x = self.prenet(x, x_mask)
        # encoder
        x = self.encoder(x, x_mask)
        # postnet
        if hasattr(self, "postnet"):
            x = self.postnet(x) * x_mask
        # set duration predictor input
        if g is not None:
            g_exp = g.expand(-1, -1, x.size(-1))
            x_dp = torch.cat([x.detach(), g_exp], 1)
        else:
            x_dp = x.detach()
        # final projection layer
        x_m = self.proj_m(x) * x_mask
        if not self.mean_only:
            x_logs = self.proj_s(x) * x_mask
        else:
            x_logs = torch.zeros_like(x_m)
        # duration predictor
        logw = self.duration_predictor(x_dp, x_mask)
        return x_m, x_logs, logw, x_mask