mirror of https://github.com/coqui-ai/TTS.git
142 lines
3.5 KiB
Python
142 lines
3.5 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
import importlib
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict, List, Tuple, TypeVar
|
|
|
|
import torch
|
|
from coqpit import Coqpit
|
|
|
|
# DISTRIBUTED
|
|
from torch import nn
|
|
|
|
# Placeholder type variable used as the return annotation of the
# data-loader factory methods on ``TrainerAbstract``.
_DataLoader = TypeVar(
    "_DataLoader",
)
|
|
|
|
|
|
@dataclass
class TrainingArgs(Coqpit):
    """Trainer arguments that are parsed externally (e.g. CLI).

    Every field defaults to an empty string (or ``0`` for ``rank``); the
    ``help`` entry in each field's metadata carries the user-facing
    description of the option.
    """

    # Existing training folder to resume in place from its last checkpoint.
    continue_path: str = field(
        default="",
        metadata={
            "help": "Path to a training folder to continue training. Restore the model from the last checkpoint and continue training under the same folder."
        },
    )
    # Checkpoint whose weights seed a brand-new training run.
    restore_path: str = field(
        default="",
        metadata={
            "help": "Path to a model checkpoint. Restore the model with the given checkpoint and start a new training."
        },
    )
    # Model file from which the best loss so far is read.
    best_path: str = field(
        default="",
        metadata={
            "help": "Best model file to be used for extracting best loss. If not specified, the latest best model in continue path is used"
        },
    )
    # Configuration file for the run.
    config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
    # Distributed-training process rank.
    rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
    # Distributed-training process group identifier.
    group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})
|
|
|
|
|
|
# pylint: disable=import-outside-toplevel, too-many-public-methods
|
|
|
|
|
|
class TrainerAbstract(ABC):
    """Abstract interface that a trainer implementation must provide.

    The ``get_*`` factories are static methods; everything else is an
    instance method. A subclass cannot be instantiated until it overrides
    every method decorated with ``@abstractmethod`` here.
    """

    @staticmethod
    def _is_apex_available() -> bool:
        """Return ``True`` if the ``apex`` package can be located, else ``False``.

        ``find_spec`` on a top-level name only searches for the module; it
        does not import ``apex``.
        """
        # ``import importlib`` alone does not guarantee that the
        # ``importlib.util`` submodule is loaded and bound on the package,
        # so import it explicitly before using it.
        import importlib.util

        return importlib.util.find_spec("apex") is not None

    @staticmethod
    @abstractmethod
    def get_model(*args, **kwargs) -> nn.Module:
        """Build and return the model as an ``nn.Module``.

        Accepted arguments are defined by the implementation.
        """

    @staticmethod
    @abstractmethod
    def get_optimizer(model: nn.Module, config: Coqpit) -> torch.optim.Optimizer:
        """Create the optimizer for ``model`` given ``config``.

        Args:
            model: Module whose parameters the optimizer will update.
            config: Run configuration.

        Returns:
            A ``torch.optim.Optimizer`` instance.
        """

    @staticmethod
    @abstractmethod
    def get_scheduler(
        config: Coqpit, optimizer: torch.optim.Optimizer
    ) -> torch.optim.lr_scheduler._LRScheduler:  # pylint: disable=protected-access
        """Create a learning-rate scheduler for ``optimizer`` given ``config``.

        Returns:
            A ``torch.optim.lr_scheduler._LRScheduler`` instance.
        """

    @staticmethod
    @abstractmethod
    def get_criterion(config: Coqpit) -> nn.Module:
        """Create the criterion (loss) module described by ``config``."""

    @abstractmethod
    def restore_model(self, *args, **kwargs) -> Tuple:
        """Restore model state from a checkpoint.

        Returns:
            A tuple of restored objects; its contents are implementation-defined.
        """

    @abstractmethod
    def get_train_dataloader(self, *args, **kwargs) -> _DataLoader:
        """Return the data loader used for training."""

    @abstractmethod
    def get_eval_dataloder(self, *args, **kwargs) -> _DataLoader:
        """Return the data loader used for evaluation.

        NOTE(review): the name is misspelled ("dataloder"). It is kept as-is
        because renaming an abstract method would break any subclass that
        overrides this exact spelling.
        """

    @abstractmethod
    def format_batch(self, batch: List) -> Dict:
        """Convert a raw ``batch`` list into the dict taken by the step methods."""

    @abstractmethod
    def _train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        """Run one training step on a formatted ``batch`` using ``criterion``.

        Returns:
            A pair of dicts (presumably model outputs and losses -- confirm
            against the concrete implementation).
        """

    @abstractmethod
    def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
        """Perform one training iteration on a formatted ``batch``.

        Args:
            batch: Formatted batch dict.
            batch_n_steps: Presumably the number of steps per epoch -- TODO confirm.
            step: Presumably the index of the current step -- TODO confirm.
            loader_start_time: Timestamp (float) presumably taken when loading
                this batch started -- TODO confirm.

        Returns:
            A pair of dicts.
        """

    @abstractmethod
    def train_epoch(self) -> None:
        """Run one training epoch."""

    @abstractmethod
    def _eval_step(self, batch: Dict) -> Tuple[Dict, Dict]:
        """Run one evaluation step on a formatted ``batch``; return a pair of dicts."""

    @abstractmethod
    def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
        """Perform one evaluation iteration at index ``step``; return a pair of dicts."""

    @abstractmethod
    def eval_epoch(self) -> None:
        """Run one evaluation epoch."""

    @abstractmethod
    def test_run(self) -> None:
        """Run the test pass (exact behavior is implementation-defined)."""

    @abstractmethod
    def fit(self) -> None:
        """Run the complete training procedure."""

    @abstractmethod
    def save_best_model(self) -> None:
        """Save the model when it is the best seen so far (implementation-defined)."""

    @abstractmethod
    def on_epoch_start(self) -> None:
        """Hook for the start of an epoch."""

    @abstractmethod
    def on_epoch_end(self) -> None:
        """Hook for the end of an epoch."""

    @abstractmethod
    def on_train_step_start(self) -> None:
        """Hook for the start of a training step."""

    @abstractmethod
    def on_train_step_end(self) -> None:
        """Hook for the end of a training step."""
|