from typing import List

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils import parametrize

from TTS.vocoder.layers.lvc_block import LVCBlock

LRELU_SLOPE = 0.1


class UnivnetGenerator(torch.nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        cond_channels: int,
        upsample_factors: List[int],
        lvc_layers_each_block: int,
        lvc_kernel_size: int,
        kpnet_hidden_channels: int,
        kpnet_conv_size: int,
        dropout: float,
        use_weight_norm=True,
    ):
        """Univnet Generator network.

        Paper: https://arxiv.org/pdf/2106.07889.pdf

        Args:
            in_channels (int): Number of input tensor channels.
            out_channels (int): Number of channels of the output tensor.
            hidden_channels (int): Number of hidden network channels.
            cond_channels (int): Number of channels of the conditioning tensors.
            upsample_factors (List[int]): List of uplsample factors for the upsampling layers.
            lvc_layers_each_block (int): Number of LVC layers in each block.
            lvc_kernel_size (int): Kernel size of the LVC layers.
            kpnet_hidden_channels (int): Number of hidden channels in the key-point network.
            kpnet_conv_size (int): Number of convolution channels in the key-point network.
            dropout (float): Dropout rate.
            use_weight_norm (bool, optional): Enable/disable weight norm. Defaults to True.
        """

        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cond_channels = cond_channels
        self.upsample_scale = np.prod(upsample_factors)
        self.lvc_block_nums = len(upsample_factors)

        # define first convolution
        self.first_conv = torch.nn.Conv1d(
            in_channels, hidden_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True
        )

        # define residual blocks
        self.lvc_blocks = torch.nn.ModuleList()
        cond_hop_length = 1
        for n in range(self.lvc_block_nums):
            cond_hop_length = cond_hop_length * upsample_factors[n]
            lvcb = LVCBlock(
                in_channels=hidden_channels,
                cond_channels=cond_channels,
                upsample_ratio=upsample_factors[n],
                conv_layers=lvc_layers_each_block,
                conv_kernel_size=lvc_kernel_size,
                cond_hop_length=cond_hop_length,
                kpnet_hidden_channels=kpnet_hidden_channels,
                kpnet_conv_size=kpnet_conv_size,
                kpnet_dropout=dropout,
            )
            self.lvc_blocks += [lvcb]

        # define output layers
        self.last_conv_layers = torch.nn.ModuleList(
            [
                torch.nn.Conv1d(
                    hidden_channels, out_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True
                ),
            ]
        )

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, c):
        """Calculate forward propagation.
        Args:
            c (Tensor): Local conditioning auxiliary features (B, C ,T').
        Returns:
            Tensor: Output tensor (B, out_channels, T)
        """
        # random noise
        x = torch.randn([c.shape[0], self.in_channels, c.shape[2]])
        x = x.to(self.first_conv.bias.device)
        x = self.first_conv(x)

        for n in range(self.lvc_block_nums):
            x = self.lvc_blocks[n](x, c)

        # apply final layers
        for f in self.last_conv_layers:
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = f(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""

        def _remove_weight_norm(m):
            try:
                # print(f"Weight norm is removed from {m}.")
                parametrize.remove_parametrizations(m, "weight")
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization module from all of the layers."""

        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.parametrizations.weight_norm(m)
                # print(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    @staticmethod
    def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
        assert layers % stacks == 0
        layers_per_cycle = layers // stacks
        dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
        return (kernel_size - 1) * sum(dilations) + 1

    @property
    def receptive_field_size(self):
        """Return receptive field size."""
        return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)

    @torch.no_grad()
    def inference(self, c):
        """Perform inference.
        Args:
            c (Tensor): Local conditioning auxiliary features :math:`(B, C, T)`.
        Returns:
            Tensor: Output tensor (T, out_channels)
        """
        x = torch.randn([c.shape[0], self.in_channels, c.shape[2]])
        x = x.to(self.first_conv.bias.device)

        c = c.to(next(self.parameters()))
        return self.forward(c)