import tensorflow as tf


class ReflectionPad1d(tf.keras.layers.Layer):
    """Reflection-pad axis 1 of a rank-4 tensor.

    Only axis 1 is padded (``padding`` mirrored elements on each side); the
    other three axes are left untouched. The convolutions that consume this
    layer treat axis 1 as time, presumably with a ``(batch, time, 1, channels)``
    layout -- NOTE(review): confirm against callers.

    Args:
        padding: Number of elements to mirror onto each side of axis 1. In
            ``"REFLECT"`` mode ``tf.pad`` requires this to be smaller than the
            size of axis 1.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``,
            ``dtype``, ``trainable``) so the layer can be named and rebuilt
            from its config.
    """

    def __init__(self, padding, **kwargs):
        super().__init__(**kwargs)
        self.padding = padding

    def call(self, x):
        # One [before, after] pair per axis; the 4 pairs require a rank-4 input.
        return tf.pad(x, [[0, 0], [self.padding, self.padding], [0, 0], [0, 0]], "REFLECT")

    def get_config(self):
        """Return the layer config, including ``padding``, for serialization."""
        config = super().get_config()
        config["padding"] = self.padding
        return config


class ResidualStack(tf.keras.layers.Layer):
    """Stack of dilated residual blocks built from 2-D convolutions.

    Block ``idx`` (``0 <= idx < num_res_blocks``) applies, in order:
    LeakyReLU(0.2) -> ReflectionPad1d -> Conv2D with a ``(kernel_size, 1)``
    kernel and dilation ``kernel_size ** idx`` along axis 1 -> LeakyReLU(0.2)
    -> 1x1 Conv2D. The output of a separate 1x1 "shortcut" Conv2D applied to
    the block's input is then added to the block's output. Blocks run
    sequentially, each consuming the previous block's result.

    Args:
        channels: Number of filters of every convolution (block and shortcut).
        num_res_blocks: Number of residual blocks; 0 makes ``call`` return its
            input unchanged.
        kernel_size: Kernel length along axis 1; must be odd so that the
            symmetric reflection padding keeps the length of axis 1 unchanged.
        name: Keras name for this layer.

    Raises:
        ValueError: If ``kernel_size`` is not odd.
    """

    def __init__(self, channels, num_res_blocks, kernel_size, name):
        super().__init__(name=name)

        # Explicit check rather than ``assert`` so it still runs under
        # ``python -O``; an even kernel would make each block shrink axis 1 and
        # the residual addition in ``call`` would fail on a shape mismatch.
        if (kernel_size - 1) % 2 != 0:
            raise ValueError(" [!] kernel_size has to be odd.")
        base_padding = (kernel_size - 1) // 2

        # Name suffix of each block's dilated conv; its 1x1 conv uses this + 2.
        # Presumably these mirror positions in a sequential container of a
        # reference implementation (e.g. for porting weights) -- verify before
        # changing, since layer names end up in variable/checkpoint names.
        dilated_conv_pos = 2

        self.blocks = []
        for idx in range(num_res_blocks):
            dilation = kernel_size**idx
            # A 'valid' conv with this dilation shortens axis 1 by
            # dilation * (kernel_size - 1) == 2 * base_padding * dilation,
            # which exactly cancels the padding added on both sides.
            padding = base_padding * dilation
            self.blocks.append([
                tf.keras.layers.LeakyReLU(0.2),
                ReflectionPad1d(padding),
                tf.keras.layers.Conv2D(filters=channels,
                                       kernel_size=(kernel_size, 1),
                                       dilation_rate=(dilation, 1),
                                       use_bias=True,
                                       padding='valid',
                                       name=f'blocks.{idx}.{dilated_conv_pos}'),
                tf.keras.layers.LeakyReLU(0.2),
                tf.keras.layers.Conv2D(filters=channels,
                                       kernel_size=(1, 1),
                                       use_bias=True,
                                       name=f'blocks.{idx}.{dilated_conv_pos + 2}'),
            ])

        # One 1x1 projection per block for the residual path.
        self.shortcuts = [
            tf.keras.layers.Conv2D(channels,
                                   kernel_size=1,
                                   use_bias=True,
                                   name=f'shortcuts.{i}')
            for i in range(num_res_blocks)
        ]

    def call(self, x):
        """Apply every residual block in sequence and return the result."""
        for block, shortcut in zip(self.blocks, self.shortcuts):
            # Shortcut is taken from the block input, before x is overwritten.
            residual = shortcut(x)
            for layer in block:
                x = layer(x)
            x = x + residual
        return x