Update BaseTrainerModel

Eren Gölge 2022-02-20 11:32:28 +01:00
parent b0cff949f5
commit c911729896
1 changed file with 54 additions and 39 deletions


@@ -1,39 +1,28 @@
 from abc import ABC, abstractmethod
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple
 import numpy as np
 import torch
 from coqpit import Coqpit
 from torch import nn
+# pylint: skip-file
-class BaseModel(nn.Module, ABC):
+class BaseTrainerModel(ABC, nn.Module):
"""Abstract 🐸TTS class. Every new 🐸TTS model must inherit this.
Notes on input/output tensor shapes:
Any input or output tensor of the model must be shaped as
- 3D tensors `batch x time x channels`
- 2D tensors `batch x channels`
- 1D tensors `batch x 1`
"""
def __init__(self, config: Coqpit):
super().__init__()
     @staticmethod
     @abstractmethod
     def init_from_config(config: Coqpit):
         """Init the model from given config.
         Override this depending on your model.
         """
-        pass
+        ...
     @abstractmethod
     def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:
-        """Forward pass for the model mainly used in training.
+        """Forward ... for the model mainly used in training.
         You can be flexible here and use different number of arguments and argument names since it is intended to be
         used by `train_step()` without exposing it out of the model.
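For orientation, a minimal sketch of what a concrete subclass looks like for the two methods in this hunk. The MyTTSModel name, the layer sizes, and building the model straight from its Coqpit config are illustrative assumptions, not part of this commit; the remaining abstract methods (inference, train_step, eval_step, ...) are omitted here:

from typing import Dict

import torch
from coqpit import Coqpit
from torch import nn

class MyTTSModel(BaseTrainerModel):
    # hypothetical subclass implementing only the two methods shown above
    def __init__(self, config: Coqpit):
        super().__init__()
        self.net = nn.Linear(80, 80)  # placeholder layer; sizes are made up

    @staticmethod
    def init_from_config(config: Coqpit):
        # construct the model from its Coqpit config, as the docstring asks
        return MyTTSModel(config)

    def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:
        # input is a 3D `batch x time x channels` tensor, per the class docstring
        return {"model_outputs": self.net(input)}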
@@ -51,7 +40,7 @@ class BaseModel(nn.Module, ABC):
     @abstractmethod
     def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
-        """Forward pass for inference.
+        """Forward ... for inference.
         We don't use `*kwargs` since it is problematic with the TorchScript API.
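A sketch of why the fixed signature matters: TorchScript cannot compile variadic **kwargs, so inference keeps an explicit argument list. The body below is a hypothetical continuation of the subclass sketched above:

from typing import Dict

import torch

def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
    # explicit arguments only; **kwargs would break TorchScript compilation
    with torch.no_grad():
        return {"model_outputs": self.net(input)}  # self.net from the sketch above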
@@ -66,9 +55,25 @@ class BaseModel(nn.Module, ABC):
         ...
         return outputs_dict
+    def format_batch(self, batch: Dict) -> Dict:
+        """Format batch returned by the data loader before sending it to the model.
+        If not implemented, model uses the batch as is.
+        Can be used for data augmentation, feature extraction, etc.
+        """
+        return batch
+    def format_batch_on_device(self, batch: Dict) -> Dict:
+        """Format batch on device before sending it to the model.
+        If not implemented, model uses the batch as is.
+        Can be used for data augmentation, feature extraction, etc.
+        """
+        return batch
     @abstractmethod
     def train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
-        """Perform a single training step. Run the model forward pass and compute losses.
+        """Perform a single training step. Run the model forward ... and compute losses.
         Args:
             batch (Dict): Input tensors.
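The two new hooks split batch formatting around the host-to-device copy: format_batch is the place for cheap CPU-side fixes, format_batch_on_device for heavier feature extraction once tensors are already on the GPU. A sketch, assuming a hypothetical "waveform" batch key and spectrogram feature:

from typing import Dict

import torch

def format_batch(self, batch: Dict) -> Dict:
    # CPU side: cheap dtype fixes before the batch is moved to the device
    batch["waveform"] = batch["waveform"].float()
    return batch

def format_batch_on_device(self, batch: Dict) -> Dict:
    # device side: compute features where the tensors already live
    window = torch.hann_window(1024, device=batch["waveform"].device)
    spec = torch.stft(batch["waveform"], n_fft=1024, hop_length=256,
                      window=window, return_complex=True)
    batch["spec"] = spec.abs()  # magnitude spectrogram; keys and sizes are illustrative
    return batch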
@@ -96,11 +101,11 @@ class BaseModel(nn.Module, ABC):
         Returns:
             Tuple[Dict, np.ndarray]: training plots and output waveform.
         """
-        pass
+        ...
     @abstractmethod
     def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
-        """Perform a single evaluation step. Run the model forward pass and compute losses. In most cases, you can
+        """Perform a single evaluation step. Run the model forward ... and compute losses. In most cases, you can
         call `train_step()` with no changes.
         Args:
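A sketch of the contract both steps share: take a batch and a criterion, return (model outputs, loss dict). The batch keys below are assumptions, and as the docstring notes, eval_step can usually delegate to train_step unchanged:

from typing import Dict, Tuple

from torch import nn

def train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
    outputs = self.forward(batch["input"])  # "input"/"target" keys are made up
    loss = criterion(outputs["model_outputs"], batch["target"])
    return outputs, {"loss": loss}

def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
    # evaluation reuses the training step unchanged
    return self.train_step(batch, criterion)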
@@ -117,45 +122,55 @@
     def eval_log(self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int) -> None:
         """The same as `train_log()`"""
-        pass
+        ...
     @abstractmethod
-    def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None:
+    def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True) -> None:
         """Load a checkpoint and get ready for training or inference.
         Args:
             config (Coqpit): Model configuration.
             checkpoint_path (str): Path to the model checkpoint file.
             eval (bool, optional): If true, init model for inference else for training. Defaults to False.
+            strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
         """
         ...
     @staticmethod
     @abstractmethod
-    def init_from_config(config: Coqpit):
+    def init_from_config(config: Coqpit, samples: List[Dict] = None, verbose=False) -> "BaseTrainerModel":
         """Init the model from given config.
         Override this depending on your model.
         """
-        pass
+        ...
-    def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]:
-        """Set up and return optimizer or optimizers."""
-        pass
-    def get_lr(self) -> Union[float, List[float]]:
-        """Return learning rate(s).
-        Returns:
-            Union[float, List[float]]: Model's initial learning rates.
-        """
-        pass
-    def get_scheduler(self, optimizer: torch.optim.Optimizer):
-        pass
-    def get_criterion(self):
-        pass
-    def format_batch(self):
-        pass
+    @abstractmethod
+    def get_data_loader(
+        self,
+        config: Coqpit,
+        assets: Dict,
+        is_eval: bool,
+        data_items: List,
+        verbose: bool,
+        num_gpus: int,
+    ):
+        ...
+    # def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]:
+    #     """Set up and return optimizer or optimizers."""
+    #     ...
+    # def get_lr(self) -> Union[float, List[float]]:
+    #     """Return learning rate(s).
+    #     Returns:
+    #         Union[float, List[float]]: Model's initial learning rates.
+    #     """
+    #     ...
+    # def get_scheduler(self, optimizer: torch.optim.Optimizer):
+    #     ...
+    # def get_criterion(self):
+    #     ...
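Finally, hedged sketches of the two newly specified methods. The checkpoint layout (a "model" key holding the state dict) and the config fields batch_size/eval_batch_size are assumptions, not part of this commit; the new strict=False path is what allows a partially matching checkpoint to load, e.g. when fine-tuning with swapped-out layers:

from typing import Dict, List

import torch
from coqpit import Coqpit
from torch.utils.data import DataLoader, Dataset

class ListDataset(Dataset):
    # hypothetical thin wrapper turning pre-formatted data items into a Dataset
    def __init__(self, items: List):
        self.items = items

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int):
        return self.items[idx]

def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True) -> None:
    state = torch.load(checkpoint_path, map_location="cpu")
    self.load_state_dict(state["model"], strict=strict)  # "model" key is an assumption
    if eval:
        self.eval()  # inference mode: disable dropout, freeze batch-norm stats

def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: bool,
                    data_items: List, verbose: bool, num_gpus: int) -> DataLoader:
    # assumed config fields; shuffle only during training
    return DataLoader(
        ListDataset(data_items),
        batch_size=config.eval_batch_size if is_eval else config.batch_size,
        shuffle=not is_eval,
        drop_last=False,
    )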