import numpy as np
from torch import nn
from torch.nn.utils import weight_norm


class MelganDiscriminator(nn.Module):
    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_sizes=(5, 3),
                 base_channels=16,
                 max_channels=1024,
                 downsample_factors=(4, 4, 4, 4)):
        super(MelganDiscriminator, self).__init__()
        self.layers = nn.ModuleList()

        layer_kernel_size = np.prod(kernel_sizes)
        layer_padding = (layer_kernel_size - 1) // 2

        # initial layer
        self.layers += [
            nn.Sequential(
                nn.ReflectionPad1d(layer_padding),
                weight_norm(
                    nn.Conv1d(in_channels,
                              base_channels,
                              layer_kernel_size,
                              stride=1)), nn.LeakyReLU(0.2, inplace=True))
        ]

        # downsampling layers
        layer_in_channels = base_channels
        for downsample_factor in downsample_factors:
            layer_out_channels = min(layer_in_channels * downsample_factor,
                                     max_channels)
            layer_kernel_size = downsample_factor * 10 + 1
            layer_padding = (layer_kernel_size - 1) // 2
            layer_groups = layer_in_channels // 4
            self.layers += [
                nn.Sequential(
                    weight_norm(
                        nn.Conv1d(layer_in_channels,
                                  layer_out_channels,
                                  kernel_size=layer_kernel_size,
                                  stride=downsample_factor,
                                  padding=layer_padding,
                                  groups=layer_groups)),
                    nn.LeakyReLU(0.2, inplace=True))
            ]
            layer_in_channels = layer_out_channels

        # last 2 layers
        layer_padding1 = (kernel_sizes[0] - 1) // 2
        layer_padding2 = (kernel_sizes[1] - 1) // 2
        self.layers += [
            nn.Sequential(
                weight_norm(
                    nn.Conv1d(layer_out_channels,
                              layer_out_channels,
                              kernel_size=kernel_sizes[0],
                              stride=1,
                              padding=layer_padding1)),
                nn.LeakyReLU(0.2, inplace=True),
            ),
            weight_norm(
                nn.Conv1d(layer_out_channels,
                          out_channels,
                          kernel_size=kernel_sizes[1],
                          stride=1,
                          padding=layer_padding2)),
        ]

    def forward(self, x):
        feats = []
        for layer in self.layers:
            x = layer(x)
            feats.append(x)
        return x, feats