import torch
import torch.nn as nn
import torch.nn.functional as F


class GST(nn.Module):
    """Global Style Token Module for factorizing prosody in speech.

    See https://arxiv.org/pdf/1803.09017"""

    def __init__(self, num_mel, num_heads, num_style_tokens, embedding_dim):
        super().__init__()
        self.encoder = ReferenceEncoder(num_mel, embedding_dim)
        self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens,
                                                 embedding_dim)

    def forward(self, inputs):
        # inputs: mel spectrograms [batch_size, num_frames, num_mel]
        enc_out = self.encoder(inputs)
        # enc_out: [batch_size, embedding_dim // 2]
        style_embed = self.style_token_layer(enc_out)
        # style_embed: [batch_size, 1, embedding_dim]

        return style_embed
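
# Usage sketch (hypothetical values: num_mel=80 and embedding_dim=256 are
# assumptions, while 10 tokens and 4 heads follow the GST paper's defaults):
#
#   gst = GST(num_mel=80, num_heads=4, num_style_tokens=10, embedding_dim=256)
#   mels = torch.rand(8, 120, 80)  # [batch_size, num_frames, num_mel]
#   style = gst(mels)              # [batch_size, 1, embedding_dim]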


class ReferenceEncoder(nn.Module):
    """NN module creating a fixed size prosody embedding from a spectrogram.

    inputs: mel spectrograms [batch_size, num_spec_frames, num_mel]
    outputs: [batch_size, embedding_dim // 2]
    """

    def __init__(self, num_mel, embedding_dim):
        super().__init__()
        self.num_mel = num_mel
        filters = [1] + [32, 32, 64, 64, 128, 128]
        num_layers = len(filters) - 1
        convs = [
            nn.Conv2d(
                in_channels=filters[i],
                out_channels=filters[i + 1],
                kernel_size=(3, 3),
                stride=(2, 2),
                padding=(1, 1)) for i in range(num_layers)
        ]
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList([
            nn.BatchNorm2d(num_features=filter_size)
            for filter_size in filters[1:]
        ])

        # kernel_size (3), stride (2), and padding (1) must match the convs above
        post_conv_height = self.calculate_post_conv_height(
            num_mel, 3, 2, 1, num_layers)
        self.recurrence = nn.GRU(
            input_size=filters[-1] * post_conv_height,
            hidden_size=embedding_dim // 2,
            batch_first=True)

    def forward(self, inputs):
        batch_size = inputs.size(0)
        x = inputs.view(batch_size, 1, -1, self.num_mel)
        # x: 4D tensor [batch_size, num_channels==1, num_frames, num_mel]
        for conv, bn in zip(self.convs, self.bns):
            x = conv(x)
            x = bn(x)
            x = F.relu(x)

        x = x.transpose(1, 2)
        # x: 4D tensor [batch_size, post_conv_frames,
        #               num_channels==128, post_conv_height]
        post_conv_frames = x.size(1)
        x = x.contiguous().view(batch_size, post_conv_frames, -1)
        # x: 3D tensor [batch_size, post_conv_frames,
        #               num_channels * post_conv_height]
        self.recurrence.flatten_parameters()
        _, out = self.recurrence(x)
        # out: final GRU hidden state, 3D tensor
        #      [num_layers * num_directions == 1, batch_size, embedding_dim // 2]

        return out.squeeze(0)

    @staticmethod
    def calculate_post_conv_height(height, kernel_size, stride, pad,
                                   n_convs):
        """Height of spec after n convolutions with fixed kernel/stride/pad."""
        for _ in range(n_convs):
            height = (height - kernel_size + 2 * pad) // stride + 1
        return height
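
# Shape sketch for ReferenceEncoder (num_mel=80 and embedding_dim=256 are
# assumed, not fixed by this module):
#
#   encoder = ReferenceEncoder(num_mel=80, embedding_dim=256)
#   mels = torch.rand(8, 120, 80)  # [batch_size, num_frames, num_mel]
#   ref = encoder(mels)            # [8, 128] == [batch_size, embedding_dim // 2]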


class StyleTokenLayer(nn.Module):
    """NN Module attending to style tokens based on prosody encodings."""

    def __init__(self, num_heads, num_style_tokens, embedding_dim):
        super().__init__()
        self.query_dim = embedding_dim // 2
        self.key_dim = embedding_dim // num_heads
        self.style_tokens = nn.Parameter(
            torch.empty(num_style_tokens, self.key_dim))
        nn.init.orthogonal_(self.style_tokens)
        self.attention = MultiHeadAttention(
            query_dim=self.query_dim,
            key_dim=self.key_dim,
            num_units=embedding_dim,
            num_heads=num_heads)

    def forward(self, inputs):
        batch_size = inputs.size(0)
        prosody_encoding = inputs.unsqueeze(1)
        # prosody_encoding: 3D tensor [batch_size, 1, embedding_dim // 2]
        tokens = (torch.tanh(self.style_tokens)
                  .unsqueeze(0)
                  .expand(batch_size, -1, -1))
        # tokens: 3D tensor [batch_size, num_style_tokens, key_dim]
        style_embed = self.attention(prosody_encoding, tokens)

        return style_embed
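
# Shape sketch for StyleTokenLayer (hypothetical embedding_dim=256,
# num_heads=4, num_style_tokens=10):
#
#   layer = StyleTokenLayer(num_heads=4, num_style_tokens=10,
#                           embedding_dim=256)
#   encoding = torch.rand(8, 128)  # [batch_size, embedding_dim // 2]
#   embed = layer(encoding)        # [8, 1, 256] == [batch_size, 1, embedding_dim]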


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    input:
        query --- [N, T_q, query_dim]
        key --- [N, T_k, key_dim]
    output:
        out --- [N, T_q, num_units]
    """

    def __init__(self, query_dim, key_dim, num_units, num_heads):
        super().__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim

        self.W_query = nn.Linear(
            in_features=query_dim, out_features=num_units, bias=False)
        self.W_key = nn.Linear(
            in_features=key_dim, out_features=num_units, bias=False)
        self.W_value = nn.Linear(
            in_features=key_dim, out_features=num_units, bias=False)

    def forward(self, query, key):
        queries = self.W_query(query)  # [N, T_q, num_units]
        keys = self.W_key(key)  # [N, T_k, num_units]
        values = self.W_value(key)  # [N, T_k, num_units]

        split_size = self.num_units // self.num_heads
        queries = torch.stack(
            torch.split(queries, split_size, dim=2),
            dim=0)  # [h, N, T_q, num_units/h]
        keys = torch.stack(
            torch.split(keys, split_size, dim=2),
            dim=0)  # [h, N, T_k, num_units/h]
        values = torch.stack(
            torch.split(values, split_size, dim=2),
            dim=0)  # [h, N, T_k, num_units/h]

        # scores = softmax(QK^T / sqrt(d_k)); as constructed by StyleTokenLayer,
        # key_dim equals the per-head size num_units // num_heads
        scores = torch.matmul(queries, keys.transpose(2, 3))  # [h, N, T_q, T_k]
        scores = scores / (self.key_dim ** 0.5)
        scores = F.softmax(scores, dim=3)

        # out = scores @ V
        out = torch.matmul(scores, values)  # [h, N, T_q, num_units/h]
        out = torch.cat(
            torch.split(out, 1, dim=0),
            dim=3).squeeze(0)  # merge heads back: [N, T_q, num_units]

        return out
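

if __name__ == "__main__":
    # Minimal smoke test under assumed hyperparameters (num_mel=80,
    # embedding_dim=256, 10 tokens, 4 heads); none of these values are
    # mandated by the module itself.
    gst = GST(num_mel=80, num_heads=4, num_style_tokens=10, embedding_dim=256)
    mels = torch.rand(2, 100, 80)  # [batch_size, num_frames, num_mel]
    style = gst(mels)
    assert style.shape == (2, 1, 256), style.shape
    print("GST style embedding:", tuple(style.shape))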