Update BaseTrainerModel

Eren Gölge 2022-02-20 11:32:28 +01:00
parent b0cff949f5
commit c911729896
1 changed file with 54 additions and 39 deletions


@@ -1,39 +1,28 @@
 from abc import ABC, abstractmethod
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple
 
-import numpy as np
 import torch
 from coqpit import Coqpit
 from torch import nn
 
+# pylint: skip-file
+
 
-class BaseModel(nn.Module, ABC):
+class BaseTrainerModel(ABC, nn.Module):
     """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this.
-
-    Notes on input/output tensor shapes:
-        Any input or output tensor of the model must be shaped as
-
-        - 3D tensors `batch x time x channels`
-        - 2D tensors `batch x channels`
-        - 1D tensors `batch x 1`
     """
 
-    def __init__(self, config: Coqpit):
-        super().__init__()
-
     @staticmethod
+    @abstractmethod
     def init_from_config(config: Coqpit):
         """Init the model from given config.
 
         Override this depending on your model.
         """
-        pass
+        ...
 
     @abstractmethod
     def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:
-        """Forward pass for the model mainly used in training.
+        """Forward ... for the model mainly used in training.
 
         You can be flexible here and use different number of arguments and argument names since it is intended to be
         used by `train_step()` without exposing it out of the model.
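
For context, here is a minimal sketch of how a subclass is expected to satisfy this contract after the rename. `MyModel` and its config fields are hypothetical, the remaining abstract methods are elided, and the `"model_outputs"` key follows the convention used elsewhere in 🐸TTS rather than anything shown in this hunk:

    from typing import Dict

    import torch
    from coqpit import Coqpit
    from torch import nn


    class MyModel(BaseTrainerModel):  # hypothetical; a real subclass must implement every abstract method
        def __init__(self, config: Coqpit):
            super().__init__()
            self.layer = nn.Linear(config.in_dim, config.out_dim)  # `in_dim`/`out_dim` are assumed config fields

        @staticmethod
        def init_from_config(config: Coqpit):
            # Now enforced by @abstractmethod: each model states how to build itself from a config.
            return MyModel(config)

        def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:
            # Only `train_step()` calls this, so the signature can stay model-specific.
            return {"model_outputs": self.layer(input)}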
@@ -51,7 +40,7 @@ class BaseModel(nn.Module, ABC):
 
     @abstractmethod
     def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
-        """Forward pass for inference.
+        """Forward ... for inference.
 
         We don't use `*kwargs` since it is problematic with the TorchScript API.
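
The TorchScript caveat above is the reason for the fixed signature: `torch.jit.script` needs statically typed arguments, so a `**kwargs` forward will not compile while explicit tensor arguments do. A toy illustration (all names made up):

    from typing import Dict

    import torch


    class ScriptableHead(torch.nn.Module):
        def forward(self, x: torch.Tensor, speaker_embedding: torch.Tensor) -> Dict[str, torch.Tensor]:
            # Explicit, typed arguments compile under TorchScript.
            return {"model_outputs": x + speaker_embedding}


    scripted = torch.jit.script(ScriptableHead())  # works; a `**kwargs` forward would fail to script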
@@ -66,9 +55,25 @@ class BaseModel(nn.Module, ABC):
         ...
         return outputs_dict
 
+    def format_batch(self, batch: Dict) -> Dict:
+        """Format batch returned by the data loader before sending it to the model.
+
+        If not implemented, model uses the batch as is.
+        Can be used for data augmentation, feature extraction, etc.
+        """
+        return batch
+
+    def format_batch_on_device(self, batch: Dict) -> Dict:
+        """Format batch on device before sending it to the model.
+
+        If not implemented, model uses the batch as is.
+        Can be used for data augmentation, feature extraction, etc.
+        """
+        return batch
+
     @abstractmethod
     def train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
-        """Perform a single training step. Run the model forward pass and compute losses.
+        """Perform a single training step. Run the model forward ... and compute losses.
 
         Args:
             batch (Dict): Input tensors.
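
The two hooks added here split batch formatting into a host-side step (`format_batch`) and a step that runs after the batch is moved to the device (`format_batch_on_device`). A hedged sketch of how the on-device hook might be used for feature extraction; `torchaudio` and the batch keys are assumptions for illustration, not part of this commit:

    from typing import Dict

    import torchaudio


    class MyModel(BaseTrainerModel):  # continuing the hypothetical subclass
        def __init__(self, config):
            super().__init__()
            # The transform is an nn.Module, so it moves to the GPU together with the model.
            self.mel = torchaudio.transforms.MelSpectrogram(sample_rate=22050, n_mels=80)

        def format_batch_on_device(self, batch: Dict) -> Dict:
            # `waveform` is an assumed loader key; the mel features are computed after
            # the batch has been moved to the training device.
            batch["mel_input"] = self.mel(batch["waveform"])
            return batch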
@@ -96,11 +101,11 @@ class BaseModel(nn.Module, ABC):
 
         Returns:
             Tuple[Dict, np.ndarray]: training plots and output waveform.
         """
-        pass
+        ...
 
     @abstractmethod
     def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
-        """Perform a single evaluation step. Run the model forward pass and compute losses. In most cases, you can
+        """Perform a single evaluation step. Run the model forward ... and compute losses. In most cases, you can
         call `train_step()` with no changes.
 
         Args:
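
As the docstring notes, `eval_step` can usually just delegate; inside a subclass the override is a one-liner:

    from typing import Dict, Tuple

    from torch import nn


    def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        # Identical forward pass and losses; the evaluation loop controls grad/eval mode.
        return self.train_step(batch, criterion)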
@@ -117,45 +122,55 @@ class BaseModel(nn.Module, ABC):
     def eval_log(self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int) -> None:
         """The same as `train_log()`"""
-        pass
+        ...
 
     @abstractmethod
-    def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None:
+    def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True) -> None:
         """Load a checkpoint and get ready for training or inference.
 
         Args:
             config (Coqpit): Model configuration.
             checkpoint_path (str): Path to the model checkpoint file.
             eval (bool, optional): If true, init model for inference else for training. Defaults to False.
+            strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
         """
         ...
 
     @staticmethod
     @abstractmethod
-    def init_from_config(config: Coqpit):
+    def init_from_config(config: Coqpit, samples: List[Dict] = None, verbose=False) -> "BaseTrainerModel":
         """Init the model from given config.
 
         Override this depending on your model.
         """
-        pass
+        ...
 
-    def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]:
-        """Setup and return optimizer or optimizers."""
-        pass
-
-    def get_lr(self) -> Union[float, List[float]]:
-        """Return learning rate(s).
-
-        Returns:
-            Union[float, List[float]]: Model's initial learning rates.
-        """
-        pass
-
-    def get_scheduler(self, optimizer: torch.optim.Optimizer):
-        pass
-
-    def get_criterion(self):
-        pass
-
-    def format_batch(self):
-        pass
+    @abstractmethod
+    def get_data_loader(
+        self,
+        config: Coqpit,
+        assets: Dict,
+        is_eval: bool,
+        data_items: List,
+        verbose: bool,
+        num_gpus: int):
+        ...
+
+    # def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]:
+    #     """Set up and return optimizer or optimizers."""
+    #     ...
+
+    # def get_lr(self) -> Union[float, List[float]]:
+    #     """Return learning rate(s).
+    #
+    #     Returns:
+    #         Union[float, List[float]]: Model's initial learning rates.
+    #     """
+    #     ...
+
+    # def get_scheduler(self, optimizer: torch.optim.Optimizer):
+    #     ...
+
+    # def get_criterion(self):
+    #     ...
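
Finally, a hedged sketch of how the extended `load_checkpoint` and the new abstract `get_data_loader` might be implemented in a subclass. The `"model"` checkpoint key, `MyDataset`, and the config fields are illustrative assumptions, not dictated by this commit:

    from typing import Dict, List

    import torch
    from coqpit import Coqpit
    from torch.utils.data import DataLoader, Dataset


    class MyDataset(Dataset):  # hypothetical dataset over the raw data items
        def __init__(self, items: List):
            self.items = items

        def __len__(self):
            return len(self.items)

        def __getitem__(self, idx):
            return self.items[idx]


    class MyModel(BaseTrainerModel):  # other abstract methods elided
        def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True) -> None:
            # Assumes the checkpoint dict stores the weights under a "model" key.
            state = torch.load(checkpoint_path, map_location="cpu")
            self.load_state_dict(state["model"], strict=strict)  # strict=False tolerates missing/renamed keys
            if eval:
                self.eval()  # switch off dropout/batch-norm updates for inference

        def get_data_loader(
            self, config: Coqpit, assets: Dict, is_eval: bool, data_items: List, verbose: bool, num_gpus: int
        ) -> DataLoader:
            dataset = MyDataset(data_items)
            return DataLoader(
                dataset,
                batch_size=config.eval_batch_size if is_eval else config.batch_size,  # assumed config fields
                shuffle=not is_eval,  # no shuffling during evaluation
            )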