coqui-tts/TTS/trainer.py

# -*- coding: utf-8 -*-
import importlib
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, TypeVar

import torch
from coqpit import Coqpit

# DISTRIBUTED
from torch import nn

_DataLoader = TypeVar("_DataLoader")


@dataclass
class TrainingArgs(Coqpit):
    """Trainer arguments that are parsed externally (e.g. from the CLI)."""

    continue_path: str = field(
        default="",
        metadata={
            "help": "Path to a training folder to continue training. Restores the model from the last checkpoint and continues training under the same folder."
        },
    )
    restore_path: str = field(
        default="",
        metadata={
            "help": "Path to a model checkpoint. Restores the model with the given checkpoint and starts a new training run."
        },
    )
    best_path: str = field(
        default="",
        metadata={
            "help": "Best model file to be used for extracting the best loss. If not specified, the latest best model under `continue_path` is used."
        },
    )
    config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
    rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
    group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})
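

# --- Usage sketch (illustrative, not part of the trainer API) ----------------
# ``TrainingArgs`` is a Coqpit dataclass, so an external script can build it
# directly in Python or fill it from CLI flags through Coqpit's argparse
# helpers. The CLI helper referenced in the comment below (``parse_args``) is
# an assumption about the coqpit version in use; the direct construction is
# plain dataclass usage.
def _example_training_args() -> TrainingArgs:  # pragma: no cover
    """Build a TrainingArgs instance the way an external training script might."""
    args = TrainingArgs(restore_path="/path/to/checkpoint.pth", rank=0)
    # Alternatively, from the command line (assumed Coqpit helper):
    #     args = TrainingArgs()
    #     args.parse_args()  # e.g. --restore_path /path/to/checkpoint.pth
    return args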


# pylint: disable=import-outside-toplevel, too-many-public-methods
class TrainerAbstract(ABC):
    """Abstract interface that concrete trainer implementations are expected to follow."""

    @staticmethod
    def _is_apex_available():
        # Check whether NVIDIA Apex is installed without importing it.
        return importlib.util.find_spec("apex") is not None

    @staticmethod
    @abstractmethod
    def get_model(*args, **kwargs) -> nn.Module:
        pass

    @staticmethod
    @abstractmethod
    def get_optimizer(model: nn.Module, config: Coqpit) -> torch.optim.Optimizer:
        pass

    @staticmethod
    @abstractmethod
    def get_scheduler(
        config: Coqpit, optimizer: torch.optim.Optimizer
    ) -> torch.optim.lr_scheduler._LRScheduler:  # pylint: disable=protected-access
        pass

    @staticmethod
    @abstractmethod
    def get_criterion(config: Coqpit) -> nn.Module:
        pass

    @abstractmethod
    def restore_model(self, *args, **kwargs) -> Tuple:
        pass

    @abstractmethod
    def get_train_dataloader(self, *args, **kwargs) -> _DataLoader:
        pass

    @abstractmethod
    def get_eval_dataloader(self, *args, **kwargs) -> _DataLoader:
        pass

    @abstractmethod
    def format_batch(self, batch: List) -> Dict:
        pass

    @abstractmethod
    def _train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        pass

    @abstractmethod
    def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
        pass

    @abstractmethod
    def train_epoch(self) -> None:
        pass

    @abstractmethod
    def _eval_step(self, batch: Dict) -> Tuple[Dict, Dict]:
        pass

    @abstractmethod
    def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
        pass

    @abstractmethod
    def eval_epoch(self) -> None:
        pass

    @abstractmethod
    def test_run(self) -> None:
        pass

    @abstractmethod
    def fit(self) -> None:
        pass

    @abstractmethod
    def save_best_model(self) -> None:
        pass

    @abstractmethod
    def on_epoch_start(self) -> None:
        pass

    @abstractmethod
    def on_epoch_end(self) -> None:
        pass

    @abstractmethod
    def on_train_step_start(self) -> None:
        pass

    @abstractmethod
    def on_train_step_end(self) -> None:
        pass
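

# --- Illustrative subclass sketch (not part of the library) ------------------
# A minimal example of how the static factory methods of ``TrainerAbstract``
# could be implemented. The toy model, optimizer, scheduler, criterion, and the
# config fields read below (``lr``, ``lr_decay_step``) are placeholders for
# this sketch, not names taken from the real coqui-tts models. The remaining
# abstract methods are omitted, so this class stays abstract and serves only as
# documentation of the expected factory signatures.
class _ExampleTrainer(TrainerAbstract):  # pragma: no cover
    @staticmethod
    def get_model(*args, **kwargs) -> nn.Module:
        # Toy stand-in for a real TTS model.
        return nn.Linear(80, 80)

    @staticmethod
    def get_optimizer(model: nn.Module, config: Coqpit) -> torch.optim.Optimizer:
        return torch.optim.Adam(model.parameters(), lr=getattr(config, "lr", 1e-3))

    @staticmethod
    def get_scheduler(
        config: Coqpit, optimizer: torch.optim.Optimizer
    ) -> torch.optim.lr_scheduler._LRScheduler:  # pylint: disable=protected-access
        return torch.optim.lr_scheduler.StepLR(optimizer, step_size=getattr(config, "lr_decay_step", 1000))

    @staticmethod
    def get_criterion(config: Coqpit) -> nn.Module:
        return nn.MSELoss()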