# pylint: disable=W0613

import traceback
from pathlib import Path

try:
    import wandb
    from wandb import finish, init  # pylint: disable=W0611
except ImportError:
    wandb = None


class WandbLogger:
    """Accumulate training/eval data in ``log_dict`` and send it to Weights & Biases.

    The ``dict_to_*`` / ``train_*`` / ``eval_*`` / ``test_*`` helpers only queue
    entries in ``self.log_dict``; nothing is sent until :meth:`flush` is called
    (or :meth:`log` is called with ``flush=True``). The ``step`` arguments of the
    per-phase helpers are accepted for interface compatibility but are unused.
    """

    def __init__(self, **kwargs):
        """Attach to the active wandb run, or start one with ``wandb.init(**kwargs)``.

        Raises:
            ImportError: if ``wandb`` is not installed.
        """
        if not wandb:
            # ImportError is a subclass of Exception, so callers catching Exception still work.
            raise ImportError("install wandb using `pip install wandb` to use WandbLogger")

        # Reuse an already-active run; ``kwargs`` are only used when a new run is started.
        self.run = wandb.run if wandb.run else wandb.init(**kwargs)
        # The run config must provide ``model``; it prefixes every scope name below.
        self.model_name = self.run.config.model
        self.log_dict = {}

    def model_weights(self, model):
        """Queue summary statistics and histograms for every parameter of ``model``.

        Parameters are numbered from 1 in ``model.named_parameters()`` order.
        A single-element parameter is logged as one ``value`` scalar. Larger
        parameters get ``max``/``min``/``mean``/``std`` scalars, a histogram of
        their values, and a histogram of their gradient when one exists.
        """
        for layer_num, (name, param) in enumerate(model.named_parameters(), start=1):
            prefix = "layer{}-{}".format(layer_num, name)
            # Detach so the queued values do not keep autograd graph nodes alive
            # until flush(), and so numpy conversion does not fail on tensors that
            # require grad.
            data = param.detach()
            if data.numel() == 1:
                self.dict_to_scalar("weights", {"{}/value".format(prefix): data.max().item()})
                continue
            self.dict_to_scalar(
                "weights",
                {
                    "{}/max".format(prefix): data.max().item(),
                    "{}/min".format(prefix): data.min().item(),
                    "{}/mean".format(prefix): data.mean().item(),
                    "{}/std".format(prefix): data.std().item(),
                },
            )
            # wandb.Histogram builds its bins with numpy, which needs host memory.
            self.log_dict["weights/{}/param".format(prefix)] = wandb.Histogram(data.cpu().numpy())
            # ``grad`` is None for frozen parameters or before the first backward
            # pass; a histogram cannot be built from None.
            if param.grad is not None:
                self.log_dict["weights/{}/grad".format(prefix)] = wandb.Histogram(
                    param.grad.detach().cpu().numpy()
                )

    def dict_to_scalar(self, scope_name, stats):
        """Queue each ``stats`` entry under the key ``"<scope_name>/<key>"``."""
        for key, value in stats.items():
            self.log_dict["{}/{}".format(scope_name, key)] = value

    def dict_to_figure(self, scope_name, figures):
        """Queue each figure wrapped in ``wandb.Image`` under ``"<scope_name>/<key>"``."""
        for key, value in figures.items():
            self.log_dict["{}/{}".format(scope_name, key)] = wandb.Image(value)

    def dict_to_audios(self, scope_name, audios, sample_rate):
        """Queue each audio array wrapped in ``wandb.Audio`` under ``"<scope_name>/<key>"``.

        float16 arrays are upcast to float32 first. The array objects are assumed
        to be numpy-like (``.dtype`` / ``.astype``); TODO confirm with callers.
        If ``wandb.Audio`` raises RuntimeError, the traceback is printed and that
        entry is skipped rather than aborting the whole batch.
        """
        for key, value in audios.items():
            if value.dtype == "float16":
                value = value.astype("float32")
            try:
                self.log_dict["{}/{}".format(scope_name, key)] = wandb.Audio(value, sample_rate=sample_rate)
            except RuntimeError:
                traceback.print_exc()

    def log(self, log_dict, prefix="", flush=False):
        """Queue ``log_dict`` entries with ``prefix`` prepended to each key.

        Args:
            log_dict: mapping of key -> value to queue.
            prefix: string prepended verbatim to every key.
            flush: when True, send everything queued so far immediately.
        """
        self.log_dict.update({prefix + key: value for key, value in log_dict.items()})
        if flush:  # for cases where you don't want to accumulate data
            self.flush()

    def train_step_stats(self, step, stats):
        self.dict_to_scalar(f"{self.model_name}_TrainIterStats", stats)

    def train_epoch_stats(self, step, stats):
        self.dict_to_scalar(f"{self.model_name}_TrainEpochStats", stats)

    def train_figures(self, step, figures):
        self.dict_to_figure(f"{self.model_name}_TrainFigures", figures)

    def train_audios(self, step, audios, sample_rate):
        self.dict_to_audios(f"{self.model_name}_TrainAudios", audios, sample_rate)

    def eval_stats(self, step, stats):
        self.dict_to_scalar(f"{self.model_name}_EvalStats", stats)

    def eval_figures(self, step, figures):
        self.dict_to_figure(f"{self.model_name}_EvalFigures", figures)

    def eval_audios(self, step, audios, sample_rate):
        self.dict_to_audios(f"{self.model_name}_EvalAudios", audios, sample_rate)

    def test_audios(self, step, audios, sample_rate):
        self.dict_to_audios(f"{self.model_name}_TestAudios", audios, sample_rate)

    def test_figures(self, step, figures):
        self.dict_to_figure(f"{self.model_name}_TestFigures", figures)

    def add_text(self, title, text, step):
        """No-op: text logging is not implemented for this backend."""

    def flush(self):
        """Send all queued entries with ``wandb.log`` (if a run exists) and clear the queue."""
        if self.run:
            wandb.log(self.log_dict)
        self.log_dict = {}

    def finish(self):
        """Finish the attached wandb run, if any."""
        if self.run:
            self.run.finish()

    def log_artifact(self, file_or_dir, name, artifact_type, aliases=None):
        """Upload a file or directory as an artifact named ``"<run id>_<name>"``.

        Does nothing when no run is attached.
        """
        if not self.run:
            return
        name = "_".join([self.run.id, name])
        artifact = wandb.Artifact(name, type=artifact_type)
        data_path = Path(file_or_dir)
        if data_path.is_dir():
            artifact.add_dir(str(data_path))
        elif data_path.is_file():
            artifact.add_file(str(data_path))
        # NOTE(review): if ``file_or_dir`` does not exist, an empty artifact is still
        # logged below; consider rejecting missing paths.
        self.run.log_artifact(artifact, aliases=aliases)