from dataclasses import asdict, dataclass, field
from typing import List

from coqpit import MISSING

from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig


@dataclass
class SpeakerEncoderConfig(BaseTrainingConfig):
    """Defines parameters for Speaker Encoder model."""

    model: str = "speaker_encoder"
    audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
    datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])

    # model params
    model_params: dict = field(default_factory=lambda: {
        "input_dim": 40,  # input feature size; must match `audio.num_mels` (see `check_values()`)
        "proj_dim": 256,  # size of the output speaker embedding
        "lstm_dim": 768,  # hidden size of each LSTM layer
        "num_lstm_layers": 3,  # number of stacked LSTM layers
        "use_lstm_with_projection": True,  # project each LSTM layer's output down to `proj_dim`
    })

    storage: dict = field(default_factory=lambda: {
        "sample_from_storage_p": 0.66,  # probability of sampling from the dataset's in-memory storage
        "storage_size": 15,  # size of the in-memory storage, in number of batches
        "additive_noise": 1e-5,  # scale of the small Gaussian noise added to the data to increase robustness
    })

    # training params
    max_train_step: int = 1000  # end training when the number of training steps reaches this value
    loss: str = "angleproto"  # training loss; "ge2e" and "angleproto" are the common speaker encoder choices
    grad_clip: float = 3.0  # gradient clipping threshold
    lr: float = 0.0001  # initial learning rate
    lr_decay: bool = False  # enable/disable learning rate decay
    warmup_steps: int = 4000  # number of warm-up steps for the learning rate scheduler
    wd: float = 1e-6  # weight decay

    # logging params
    tb_model_param_stats: bool = False  # log model parameter statistics to Tensorboard
    steps_plot_stats: int = 10  # plot training stats every this many steps
    checkpoint: bool = True  # save model checkpoints during training
    save_step: int = 1000  # save a checkpoint every this many steps
    print_step: int = 20  # print training stats to the console every this many steps

    # data loader
    num_speakers_in_batch: int = MISSING  # number of speakers in each training batch; must be set by the user
    num_utters_per_speaker: int = MISSING  # number of utterances sampled per speaker; must be set by the user
    num_loader_workers: int = MISSING  # number of data loader worker processes; must be set by the user

    def check_values(self):
        super().check_values()
        c = asdict(self)
        assert (
            c["model_params"]["input_dim"] == self.audio.num_mels
        ), " [!] model input dimension must be equal to melspectrogram dimension."
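

# A minimal usage sketch (assumption: run ad hoc, not part of the original module).
# It illustrates the constraint enforced by `check_values()`: the model's
# `input_dim` must equal the mel-spectrogram dimension (`audio.num_mels`),
# since the encoder consumes mel frames directly.
if __name__ == "__main__":
    config = SpeakerEncoderConfig(
        num_speakers_in_batch=64,  # illustrative value; MISSING by default
        num_utters_per_speaker=10,  # illustrative value; MISSING by default
        num_loader_workers=4,  # illustrative value; MISSING by default
    )
    if config.model_params["input_dim"] != config.audio.num_mels:
        # `check_values()` would raise here; align the audio config with the model.
        config.audio.num_mels = config.model_params["input_dim"]
    print(asdict(config)["model_params"])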