# -*- coding: utf-8 -*-
import importlib
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, TypeVar

import torch
from coqpit import Coqpit
from torch import nn

_DataLoader = TypeVar("_DataLoader")


@dataclass
class TrainingArgs(Coqpit):
    """Trainer arguments that are parsed externally (e.g. from the CLI)."""

    continue_path: str = field(
        default="",
        metadata={
            "help": "Path to a training folder to continue from. The model is restored from the last checkpoint and training continues under the same folder."
        },
    )
    restore_path: str = field(
        default="",
        metadata={
            "help": "Path to a model checkpoint. The model is restored from the given checkpoint and a new training run is started."
        },
    )
    best_path: str = field(
        default="",
        metadata={
            "help": "Best model file used for extracting the best loss. If not specified, the latest best model under `continue_path` is used."
        },
    )
    config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
    rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
    group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})


# pylint: disable=import-outside-toplevel, too-many-public-methods
class TrainerAbstract(ABC):
    @staticmethod
    def _is_apex_available() -> bool:
        """Return True if NVIDIA Apex is installed."""
        return importlib.util.find_spec("apex") is not None

    @staticmethod
    @abstractmethod
    def get_model(*args, **kwargs) -> nn.Module:
        """Return the model to be trained."""

    @staticmethod
    @abstractmethod
    def get_optimizer(model: nn.Module, config: Coqpit) -> torch.optim.Optimizer:
        """Build the optimizer for the given model from the config."""

    @staticmethod
    @abstractmethod
    def get_scheduler(
        config: Coqpit, optimizer: torch.optim.Optimizer
    ) -> torch.optim.lr_scheduler._LRScheduler:  # pylint: disable=protected-access
        """Build the learning-rate scheduler for the given optimizer from the config."""

    @staticmethod
    @abstractmethod
    def get_criterion(config: Coqpit) -> nn.Module:
        """Build the loss criterion from the config."""

    @abstractmethod
    def restore_model(self, *args, **kwargs) -> Tuple:
        """Restore training state from a checkpoint."""

    @abstractmethod
    def get_train_dataloader(self, *args, **kwargs) -> _DataLoader:
        """Return the training data loader."""

    @abstractmethod
    def get_eval_dataloader(self, *args, **kwargs) -> _DataLoader:
        """Return the evaluation data loader."""

    @abstractmethod
    def format_batch(self, batch: List) -> Dict:
        """Convert a raw batch from the data loader into a model-ready dict."""

    @abstractmethod
    def _train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        """Run the model over a single training batch; return model outputs and losses."""

    @abstractmethod
    def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
        """Run a full optimization step over a single batch; return model outputs and losses."""

    @abstractmethod
    def train_epoch(self) -> None:
        """Run one pass over the training data."""

    @abstractmethod
    def _eval_step(self, batch: Dict) -> Tuple[Dict, Dict]:
        """Run the model over a single evaluation batch; return model outputs and losses."""

    @abstractmethod
    def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
        """Run a full evaluation step over a single batch; return model outputs and losses."""

    @abstractmethod
    def eval_epoch(self) -> None:
        """Run one pass over the evaluation data."""

    @abstractmethod
    def test_run(self) -> None:
        """Run a model-specific test pass."""

    @abstractmethod
    def fit(self) -> None:
        """Run the full training loop."""

    @abstractmethod
    def save_best_model(self) -> None:
        """Save the model when it achieves a new best loss."""

    @abstractmethod
    def on_epoch_start(self) -> None:
        """Hook called at the beginning of each epoch."""

    @abstractmethod
    def on_epoch_end(self) -> None:
        """Hook called at the end of each epoch."""

    @abstractmethod
    def on_train_step_start(self) -> None:
        """Hook called before each training step."""

    @abstractmethod
    def on_train_step_end(self) -> None:
        """Hook called after each training step."""
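
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). ``TrainingArgs`` is a plain Coqpit
# dataclass, so it can be constructed directly, e.g.
# ``TrainingArgs(restore_path="checkpoint.pth")``; in the intended flow its
# fields are filled externally, e.g. through Coqpit's argparse integration.
# The partial subclass below shows how the static factory hooks might be
# filled in. ``_ExampleTrainer``, the stand-in ``nn.Linear`` model, and the
# ``lr`` field read from the config are assumptions for demonstration only;
# the remaining abstract methods must also be overridden before the class
# can be instantiated.
# ---------------------------------------------------------------------------
class _ExampleTrainer(TrainerAbstract):
    @staticmethod
    def get_model(*args, **kwargs) -> nn.Module:
        # A stand-in model; a real trainer builds and returns its own module.
        return nn.Linear(10, 10)

    @staticmethod
    def get_optimizer(model: nn.Module, config: Coqpit) -> torch.optim.Optimizer:
        # Assumes the config defines an ``lr`` field (not declared in this module).
        return torch.optim.Adam(model.parameters(), lr=getattr(config, "lr", 1e-3))

    @staticmethod
    def get_scheduler(
        config: Coqpit, optimizer: torch.optim.Optimizer
    ) -> torch.optim.lr_scheduler._LRScheduler:  # pylint: disable=protected-access
        # A fixed StepLR for illustration; a real trainer would pick the
        # scheduler and its parameters from the config.
        return torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000)

    @staticmethod
    def get_criterion(config: Coqpit) -> nn.Module:
        # Any ``nn.Module`` loss works; MSE is used here only as an example.
        return nn.MSELoss()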