From c9117298960d83e1b1e04d23eca4cda1b1c8bc20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:32:28 +0100 Subject: [PATCH] Update BaseTrainerModel --- TTS/model.py | 93 ++++++++++++++++++++++++++++++---------------------- 1 file changed, 54 insertions(+), 39 deletions(-) diff --git a/TTS/model.py b/TTS/model.py index 6ce11e63..d7bd4f9f 100644 --- a/TTS/model.py +++ b/TTS/model.py @@ -1,39 +1,28 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Tuple -import numpy as np import torch from coqpit import Coqpit from torch import nn -# pylint: skip-file -class BaseModel(nn.Module, ABC): +class BaseTrainerModel(ABC, nn.Module): """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this. - - Notes on input/output tensor shapes: - Any input or output tensor of the model must be shaped as - - - 3D tensors `batch x time x channels` - - 2D tensors `batch x channels` - - 1D tensors `batch x 1` """ - def __init__(self, config: Coqpit): - super().__init__() - @staticmethod + @abstractmethod def init_from_config(config: Coqpit): """Init the model from given config. Override this depending on your model. """ - pass + ... @abstractmethod def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict: - """Forward pass for the model mainly used in training. + """Forward ... for the model mainly used in training. You can be flexible here and use different number of arguments and argument names since it is intended to be used by `train_step()` without exposing it out of the model. @@ -51,7 +40,7 @@ class BaseModel(nn.Module, ABC): @abstractmethod def inference(self, input: torch.Tensor, aux_input={}) -> Dict: - """Forward pass for inference. + """Forward ... for inference. We don't use `*kwargs` since it is problematic with the TorchScript API. @@ -66,9 +55,25 @@ class BaseModel(nn.Module, ABC): ... return outputs_dict + def format_batch(self, batch: Dict) -> Dict: + """Format batch returned by the data loader before sending it to the model. + + If not implemented, model uses the batch as is. + Can be used for data augmentation, feature ectraction, etc. + """ + return batch + + def format_batch_on_device(self, batch:Dict) -> Dict: + """Format batch on device before sending it to the model. + + If not implemented, model uses the batch as is. + Can be used for data augmentation, feature ectraction, etc. + """ + return batch + @abstractmethod def train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]: - """Perform a single training step. Run the model forward pass and compute losses. + """Perform a single training step. Run the model forward ... and compute losses. Args: batch (Dict): Input tensors. @@ -96,11 +101,11 @@ class BaseModel(nn.Module, ABC): Returns: Tuple[Dict, np.ndarray]: training plots and output waveform. """ - pass + ... @abstractmethod def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]: - """Perform a single evaluation step. Run the model forward pass and compute losses. In most cases, you can + """Perform a single evaluation step. Run the model forward ... and compute losses. In most cases, you can call `train_step()` with no changes. Args: @@ -117,45 +122,55 @@ class BaseModel(nn.Module, ABC): def eval_log(self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int) -> None: """The same as `train_log()`""" - pass + ... + @abstractmethod - def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None: + def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True) -> None: """Load a checkpoint and get ready for training or inference. Args: config (Coqpit): Model configuration. checkpoint_path (str): Path to the model checkpoint file. eval (bool, optional): If true, init model for inference else for training. Defaults to False. + strcit (bool, optional): Match all checkpoint keys to model's keys. Defaults to True. """ ... @staticmethod @abstractmethod - def init_from_config(config: Coqpit): + def init_from_config(config: Coqpit, samples: List[Dict] = None, verbose=False) -> "BaseTrainerModel": """Init the model from given config. Override this depending on your model. """ - pass + ... - def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]: - """Setup an return optimizer or optimizers.""" - pass + @abstractmethod + def get_data_loader( + self, + config: Coqpit, + assets: Dict, + is_eval: True, + data_items: List, + verbose: bool, + num_gpus: int): + ... - def get_lr(self) -> Union[float, List[float]]: - """Return learning rate(s). + # def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]: + # """Setup an return optimizer or optimizers.""" + # ... - Returns: - Union[float, List[float]]: Model's initial learning rates. - """ - pass + # def get_lr(self) -> Union[float, List[float]]: + # """Return learning rate(s). - def get_scheduler(self, optimizer: torch.optim.Optimizer): - pass + # Returns: + # Union[float, List[float]]: Model's initial learning rates. + # """ + # ... - def get_criterion(self): - pass + # def get_scheduler(self, optimizer: torch.optim.Optimizer): + # ... - def format_batch(self): - pass + # def get_criterion(self): + # ...