from typing import Dict, Union

import librosa
import torch
from torch import nn
from torch.nn import functional as F

from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss


class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
    """TODO: Merge this with audio.py"""

    def __init__(
        self,
        n_fft,
        hop_length,
        win_length,
        pad_wav=False,
        window="hann_window",
        sample_rate=None,
        mel_fmin=0,
        mel_fmax=None,
        n_mels=80,
        use_mel=False,
    ):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.pad_wav = pad_wav
        self.sample_rate = sample_rate
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.n_mels = n_mels
        self.use_mel = use_mel
        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
        self.mel_basis = None
        if use_mel:
            self._build_mel_basis()

    def __call__(self, x):
        """Compute spectrogram frames by torch based stft.

        Args:
            x (Tensor): input waveform

        Returns:
            Tensor: spectrogram frames.

        Shapes:
            x: [B x T] or [B x 1 x T]
        """
        if x.ndim == 2:
            x = x.unsqueeze(1)
        if self.pad_wav:
            padding = int((self.n_fft - self.hop_length) / 2)
            x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
        # B x D x T x 2
        o = torch.stft(
            x.squeeze(1),
            self.n_fft,
            self.hop_length,
            self.win_length,
            self.window,
            center=True,
            pad_mode="reflect",  # compatible with audio.py
            normalized=False,
            onesided=True,
            return_complex=False,
        )
        M = o[:, :, :, 0]
        P = o[:, :, :, 1]
        S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
        if self.use_mel:
            S = torch.matmul(self.mel_basis.to(x), S)
        return S

    def _build_mel_basis(self):
        mel_basis = librosa.filters.mel(
            self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
        )
        self.mel_basis = torch.from_numpy(mel_basis).float()


#################################
# GENERATOR LOSSES
#################################


class STFTLoss(nn.Module):
    """STFT loss. Input generate and real waveforms are converted
    to spectrograms compared with L1 and Spectral convergence losses.
    It is from ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf"""

    def __init__(self, n_fft, hop_length, win_length):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.stft = TorchSTFT(n_fft, hop_length, win_length)

    def forward(self, y_hat, y):
        y_hat_M = self.stft(y_hat)
        y_M = self.stft(y)
        # magnitude loss
        loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M))
        # spectral convergence loss
        loss_sc = torch.norm(y_M - y_hat_M, p="fro") / torch.norm(y_M, p="fro")
        return loss_mag, loss_sc


class MultiScaleSTFTLoss(torch.nn.Module):
    """Multi-scale STFT loss. Input generate and real waveforms are converted
    to spectrograms compared with L1 and Spectral convergence losses.
    It is from ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf"""

    def __init__(self, n_ffts=(1024, 2048, 512), hop_lengths=(120, 240, 50), win_lengths=(600, 1200, 240)):
        super().__init__()
        self.loss_funcs = torch.nn.ModuleList()
        for n_fft, hop_length, win_length in zip(n_ffts, hop_lengths, win_lengths):
            self.loss_funcs.append(STFTLoss(n_fft, hop_length, win_length))

    def forward(self, y_hat, y):
        N = len(self.loss_funcs)
        loss_sc = 0
        loss_mag = 0
        for f in self.loss_funcs:
            lm, lsc = f(y_hat, y)
            loss_mag += lm
            loss_sc += lsc
        loss_sc /= N
        loss_mag /= N
        return loss_mag, loss_sc


class L1SpecLoss(nn.Module):
    """L1 Loss over Spectrograms as described in HiFiGAN paper https://arxiv.org/pdf/2010.05646.pdf"""

    def __init__(
        self, sample_rate, n_fft, hop_length, win_length, mel_fmin=None, mel_fmax=None, n_mels=None, use_mel=True
    ):
        super().__init__()
        self.use_mel = use_mel
        self.stft = TorchSTFT(
            n_fft,
            hop_length,
            win_length,
            sample_rate=sample_rate,
            mel_fmin=mel_fmin,
            mel_fmax=mel_fmax,
            n_mels=n_mels,
            use_mel=use_mel,
        )

    def forward(self, y_hat, y):
        y_hat_M = self.stft(y_hat)
        y_M = self.stft(y)
        # magnitude loss
        loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M))
        return loss_mag


class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
    """Multiscale STFT loss for multi band model outputs.
    From MultiBand-MelGAN paper https://arxiv.org/abs/2005.05106"""

    # pylint: disable=no-self-use
    def forward(self, y_hat, y):
        y_hat = y_hat.view(-1, 1, y_hat.shape[2])
        y = y.view(-1, 1, y.shape[2])
        return super().forward(y_hat.squeeze(1), y.squeeze(1))


class MSEGLoss(nn.Module):
    """Mean Squared Generator Loss"""

    # pylint: disable=no-self-use
    def forward(self, score_real):
        loss_fake = F.mse_loss(score_real, score_real.new_ones(score_real.shape))
        return loss_fake


class HingeGLoss(nn.Module):
    """Hinge Discriminator Loss"""

    # pylint: disable=no-self-use
    def forward(self, score_real):
        # TODO: this might be wrong
        loss_fake = torch.mean(F.relu(1.0 - score_real))
        return loss_fake


##################################
# DISCRIMINATOR LOSSES
##################################


class MSEDLoss(nn.Module):
    """Mean Squared Discriminator Loss"""

    def __init__(
        self,
    ):
        super().__init__()
        self.loss_func = nn.MSELoss()

    # pylint: disable=no-self-use
    def forward(self, score_fake, score_real):
        loss_real = self.loss_func(score_real, score_real.new_ones(score_real.shape))
        loss_fake = self.loss_func(score_fake, score_fake.new_zeros(score_fake.shape))
        loss_d = loss_real + loss_fake
        return loss_d, loss_real, loss_fake


class HingeDLoss(nn.Module):
    """Hinge Discriminator Loss"""

    # pylint: disable=no-self-use
    def forward(self, score_fake, score_real):
        loss_real = torch.mean(F.relu(1.0 - score_real))
        loss_fake = torch.mean(F.relu(1.0 + score_fake))
        loss_d = loss_real + loss_fake
        return loss_d, loss_real, loss_fake


class MelganFeatureLoss(nn.Module):
    def __init__(
        self,
    ):
        super().__init__()
        self.loss_func = nn.L1Loss()

    # pylint: disable=no-self-use
    def forward(self, fake_feats, real_feats):
        loss_feats = 0
        num_feats = 0
        for idx, _ in enumerate(fake_feats):
            for fake_feat, real_feat in zip(fake_feats[idx], real_feats[idx]):
                loss_feats += self.loss_func(fake_feat, real_feat)
                num_feats += 1
        loss_feats = loss_feats / num_feats
        return loss_feats


#####################################
# LOSS WRAPPERS
#####################################


def _apply_G_adv_loss(scores_fake, loss_func):
    """Compute G adversarial loss function
    and normalize values"""
    adv_loss = 0
    if isinstance(scores_fake, list):
        for score_fake in scores_fake:
            fake_loss = loss_func(score_fake)
            adv_loss += fake_loss
        adv_loss /= len(scores_fake)
    else:
        fake_loss = loss_func(scores_fake)
        adv_loss = fake_loss
    return adv_loss


def _apply_D_loss(scores_fake, scores_real, loss_func):
    """Compute D loss func and normalize loss values"""
    loss = 0
    real_loss = 0
    fake_loss = 0
    if isinstance(scores_fake, list):
        # multi-scale loss
        for score_fake, score_real in zip(scores_fake, scores_real):
            total_loss, real_loss, fake_loss = loss_func(score_fake=score_fake, score_real=score_real)
            loss += total_loss
            real_loss += real_loss
            fake_loss += fake_loss
        # normalize loss values with number of scales
        loss /= len(scores_fake)
        real_loss /= len(scores_real)
        fake_loss /= len(scores_fake)
    else:
        # single scale loss
        total_loss, real_loss, fake_loss = loss_func(scores_fake, scores_real)
        loss = total_loss
    return loss, real_loss, fake_loss


##################################
# MODEL LOSSES
##################################


class GeneratorLoss(nn.Module):
    """Generator Loss Wrapper. Based on model configuration it sets a right set of loss functions and computes
    losses. It allows to experiment with different combinations of loss functions with different models by just
    changing configurations.

    Args:
        C (AttrDict): model configuration.
    """

    def __init__(self, C):
        super().__init__()
        assert not (
            C.use_mse_gan_loss and C.use_hinge_gan_loss
        ), " [!] Cannot use HingeGANLoss and MSEGANLoss together."

        self.use_stft_loss = C.use_stft_loss if "use_stft_loss" in C else False
        self.use_subband_stft_loss = C.use_subband_stft_loss if "use_subband_stft_loss" in C else False
        self.use_mse_gan_loss = C.use_mse_gan_loss if "use_mse_gan_loss" in C else False
        self.use_hinge_gan_loss = C.use_hinge_gan_loss if "use_hinge_gan_loss" in C else False
        self.use_feat_match_loss = C.use_feat_match_loss if "use_feat_match_loss" in C else False
        self.use_l1_spec_loss = C.use_l1_spec_loss if "use_l1_spec_loss" in C else False

        self.stft_loss_weight = C.stft_loss_weight if "stft_loss_weight" in C else 0.0
        self.subband_stft_loss_weight = C.subband_stft_loss_weight if "subband_stft_loss_weight" in C else 0.0
        self.mse_gan_loss_weight = C.mse_G_loss_weight if "mse_G_loss_weight" in C else 0.0
        self.hinge_gan_loss_weight = C.hinge_G_loss_weight if "hinde_G_loss_weight" in C else 0.0
        self.feat_match_loss_weight = C.feat_match_loss_weight if "feat_match_loss_weight" in C else 0.0
        self.l1_spec_loss_weight = C.l1_spec_loss_weight if "l1_spec_loss_weight" in C else 0.0

        if C.use_stft_loss:
            self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params)
        if C.use_subband_stft_loss:
            self.subband_stft_loss = MultiScaleSubbandSTFTLoss(**C.subband_stft_loss_params)
        if C.use_mse_gan_loss:
            self.mse_loss = MSEGLoss()
        if C.use_hinge_gan_loss:
            self.hinge_loss = HingeGLoss()
        if C.use_feat_match_loss:
            self.feat_match_loss = MelganFeatureLoss()
        if C.use_l1_spec_loss:
            assert C.audio["sample_rate"] == C.l1_spec_loss_params["sample_rate"]
            self.l1_spec_loss = L1SpecLoss(**C.l1_spec_loss_params)

    def forward(
        self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None
    ):
        gen_loss = 0
        adv_loss = 0
        return_dict = {}

        # STFT Loss
        if self.use_stft_loss:
            stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat[:, :, : y.size(2)].squeeze(1), y.squeeze(1))
            return_dict["G_stft_loss_mg"] = stft_loss_mg
            return_dict["G_stft_loss_sc"] = stft_loss_sc
            gen_loss = gen_loss + self.stft_loss_weight * (stft_loss_mg + stft_loss_sc)

        # L1 Spec loss
        if self.use_l1_spec_loss:
            l1_spec_loss = self.l1_spec_loss(y_hat, y)
            return_dict["G_l1_spec_loss"] = l1_spec_loss
            gen_loss = gen_loss + self.l1_spec_loss_weight * l1_spec_loss

        # subband STFT Loss
        if self.use_subband_stft_loss:
            subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub)
            return_dict["G_subband_stft_loss_mg"] = subband_stft_loss_mg
            return_dict["G_subband_stft_loss_sc"] = subband_stft_loss_sc
            gen_loss = gen_loss + self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc)

        # multiscale MSE adversarial loss
        if self.use_mse_gan_loss and scores_fake is not None:
            mse_fake_loss = _apply_G_adv_loss(scores_fake, self.mse_loss)
            return_dict["G_mse_fake_loss"] = mse_fake_loss
            adv_loss = adv_loss + self.mse_gan_loss_weight * mse_fake_loss

        # multiscale Hinge adversarial loss
        if self.use_hinge_gan_loss and not scores_fake is not None:
            hinge_fake_loss = _apply_G_adv_loss(scores_fake, self.hinge_loss)
            return_dict["G_hinge_fake_loss"] = hinge_fake_loss
            adv_loss = adv_loss + self.hinge_gan_loss_weight * hinge_fake_loss

        # Feature Matching Loss
        if self.use_feat_match_loss and not feats_fake is None:
            feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
            return_dict["G_feat_match_loss"] = feat_match_loss
            adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss
        return_dict["loss"] = gen_loss + adv_loss
        return_dict["G_gen_loss"] = gen_loss
        return_dict["G_adv_loss"] = adv_loss
        return return_dict


class DiscriminatorLoss(nn.Module):
    """Like ```GeneratorLoss```"""

    def __init__(self, C):
        super().__init__()
        assert not (
            C.use_mse_gan_loss and C.use_hinge_gan_loss
        ), " [!] Cannot use HingeGANLoss and MSEGANLoss together."

        self.use_mse_gan_loss = C.use_mse_gan_loss
        self.use_hinge_gan_loss = C.use_hinge_gan_loss

        if C.use_mse_gan_loss:
            self.mse_loss = MSEDLoss()
        if C.use_hinge_gan_loss:
            self.hinge_loss = HingeDLoss()

    def forward(self, scores_fake, scores_real):
        loss = 0
        return_dict = {}

        if self.use_mse_gan_loss:
            mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss(
                scores_fake=scores_fake, scores_real=scores_real, loss_func=self.mse_loss
            )
            return_dict["D_mse_gan_loss"] = mse_D_loss
            return_dict["D_mse_gan_real_loss"] = mse_D_real_loss
            return_dict["D_mse_gan_fake_loss"] = mse_D_fake_loss
            loss += mse_D_loss

        if self.use_hinge_gan_loss:
            hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss(
                scores_fake=scores_fake, scores_real=scores_real, loss_func=self.hinge_loss
            )
            return_dict["D_hinge_gan_loss"] = hinge_D_loss
            return_dict["D_hinge_gan_real_loss"] = hinge_D_real_loss
            return_dict["D_hinge_gan_fake_loss"] = hinge_D_fake_loss
            loss += hinge_D_loss

        return_dict["loss"] = loss
        return return_dict


class WaveRNNLoss(nn.Module):
    def __init__(self, wave_rnn_mode: Union[str, int]):
        super().__init__()
        if wave_rnn_mode == "mold":
            self.loss_func = discretized_mix_logistic_loss
        elif wave_rnn_mode == "gauss":
            self.loss_func = gaussian_loss
        elif isinstance(wave_rnn_mode, int):
            self.loss_func = torch.nn.CrossEntropyLoss()
        else:
            raise ValueError(" [!] Unknown mode for Wavernn.")

    def forward(self, y_hat, y) -> Dict:
        loss = self.loss_func(y_hat, y)
        return {"loss": loss}