import torch
from torch import nn
from ..generic.normalization import LayerNorm


class MDNBlock(nn.Module):
    """Mixture of Density Network implementation
    https://arxiv.org/pdf/2003.01950.pdf
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.out_channels = out_channels
        self.mdn = nn.Sequential(nn.Conv1d(in_channels, in_channels, 1),
                                 LayerNorm(in_channels),
                                 nn.ReLU(),
                                 nn.Dropout(0.1),
                                 nn.Conv1d(in_channels, out_channels, 1))

    def forward(self, x):
        """
        Shapes:
            x: [B, C_in, T]
            mu: [B, C_out // 2, T]
            log_sigma: [B, C_out // 2, T]
        """
        mu_sigma = self.mdn(x)
        # TODO: check this sigmoid
        # mu = torch.sigmoid(mu_sigma[:, :self.out_channels//2, :])
        # split the projection channel-wise into mean and log-scale halves
        mu = mu_sigma[:, : self.out_channels // 2, :]
        log_sigma = mu_sigma[:, self.out_channels // 2 :, :]
        return mu, log_sigma
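

# Illustrative usage sketch. The channel sizes below (256 hidden channels,
# 80-dim targets) are assumed for demonstration only and are not taken from
# any model config. Run this file as a module (``python -m ...``) so the
# relative import above resolves.
if __name__ == "__main__":
    block = MDNBlock(in_channels=256, out_channels=2 * 80)
    x = torch.randn(2, 256, 50)  # [B, C_in, T]
    mu, log_sigma = block(x)  # each [B, 80, T]
    # one common way such outputs are consumed: per-frame Gaussian log-likelihood
    target = torch.randn(2, 80, 50)
    log_prob = torch.distributions.Normal(mu, log_sigma.exp()).log_prob(target)
    print(mu.shape, log_sigma.shape, log_prob.shape)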