mirror of https://github.com/coqui-ai/TTS.git
Update BaseTrainerModel
commit c911729896
parent b0cff949f5
Changed file: TTS/model.py (93 changed lines)
@@ -1,39 +1,28 @@
 from abc import ABC, abstractmethod
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple
 
-import numpy as np
 import torch
 from coqpit import Coqpit
 from torch import nn
 
-# pylint: skip-file
-
 
-class BaseModel(nn.Module, ABC):
+class BaseTrainerModel(ABC, nn.Module):
     """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this.
-
-    Notes on input/output tensor shapes:
-        Any input or output tensor of the model must be shaped as
-
-        - 3D tensors `batch x time x channels`
-        - 2D tensors `batch x channels`
-        - 1D tensors `batch x 1`
     """
 
-    def __init__(self, config: Coqpit):
-        super().__init__()
-
     @staticmethod
+    @abstractmethod
     def init_from_config(config: Coqpit):
         """Init the model from given config.
 
         Override this depending on your model.
         """
-        pass
+        ...
 
     @abstractmethod
     def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:
-        """Forward pass for the model mainly used in training.
+        """Forward ... for the model mainly used in training.
 
         You can be flexible here and use different number of arguments and argument names since it is intended to be
         used by `train_step()` without exposing it out of the model.
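Editor's note on the decorator stacking added above: the order matters. Per the Python docs, `@abstractmethod` must be the innermost decorator when combined with `@staticmethod`, which is exactly how the hunk stacks them. A minimal standalone sketch of the pattern (the class names here are illustrative, not from the repo):

from abc import ABC, abstractmethod

class Base(ABC):
    @staticmethod       # outer decorator
    @abstractmethod     # inner decorator: must be applied first
    def init_from_config(config):
        ...

# Base() now raises TypeError until a subclass provides the method:
class Impl(Base):
    @staticmethod
    def init_from_config(config):
        return Impl()

Impl.init_from_config({})  # OK: the override satisfies the ABC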
@@ -51,7 +40,7 @@ class BaseModel(nn.Module, ABC):
 
     @abstractmethod
     def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
-        """Forward pass for inference.
+        """Forward ... for inference.
 
         We don't use `*kwargs` since it is problematic with the TorchScript API.
 
@@ -66,9 +55,25 @@ class BaseModel(nn.Module, ABC):
         ...
         return outputs_dict
 
+    def format_batch(self, batch: Dict) -> Dict:
+        """Format batch returned by the data loader before sending it to the model.
+
+        If not implemented, model uses the batch as is.
+        Can be used for data augmentation, feature extraction, etc.
+        """
+        return batch
+
+    def format_batch_on_device(self, batch: Dict) -> Dict:
+        """Format batch on device before sending it to the model.
+
+        If not implemented, model uses the batch as is.
+        Can be used for data augmentation, feature extraction, etc.
+        """
+        return batch
+
     @abstractmethod
     def train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
-        """Perform a single training step. Run the model forward pass and compute losses.
+        """Perform a single training step. Run the model forward ... and compute losses.
 
         Args:
             batch (Dict): Input tensors.
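Editor's note: `format_batch` and `format_batch_on_device` arrive as no-op hooks (each returns the batch unchanged), so existing subclasses keep working. Judging by the names and docstrings, the first runs on the loader output before the batch reaches the device, the second after. A hedged sketch of overriding them; the subclass and batch keys are hypothetical:

class MyModel(BaseTrainerModel):  # hypothetical subclass
    def format_batch(self, batch):
        # CPU-side: e.g. reorder axes to the `batch x time x channels` convention
        batch["mel"] = batch["mel"].transpose(1, 2)
        return batch

    def format_batch_on_device(self, batch):
        # device-side: e.g. derive a feature that is cheap to compute on GPU
        batch["mel_db"] = batch["mel"].clamp(min=1e-5).log10()
        return batch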
@@ -96,11 +101,11 @@ class BaseModel(nn.Module, ABC):
         Returns:
             Tuple[Dict, np.ndarray]: training plots and output waveform.
         """
-        pass
+        ...
 
     @abstractmethod
     def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
-        """Perform a single evaluation step. Run the model forward pass and compute losses. In most cases, you can
+        """Perform a single evaluation step. Run the model forward ... and compute losses. In most cases, you can
         call `train_step()` with no changes.
 
         Args:
@@ -117,45 +122,55 @@ class BaseModel(nn.Module, ABC):
 
     def eval_log(self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int) -> None:
         """The same as `train_log()`"""
-        pass
+        ...
 
     @abstractmethod
-    def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None:
+    def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True) -> None:
         """Load a checkpoint and get ready for training or inference.
 
         Args:
             config (Coqpit): Model configuration.
             checkpoint_path (str): Path to the model checkpoint file.
             eval (bool, optional): If true, init model for inference else for training. Defaults to False.
+            strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
         """
         ...
 
     @staticmethod
     @abstractmethod
-    def init_from_config(config: Coqpit):
+    def init_from_config(config: Coqpit, samples: List[Dict] = None, verbose=False) -> "BaseTrainerModel":
         """Init the model from given config.
 
         Override this depending on your model.
         """
-        pass
+        ...
 
-    def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]:
-        """Setup an return optimizer or optimizers."""
-        pass
+    @abstractmethod
+    def get_data_loader(
+        self,
+        config: Coqpit,
+        assets: Dict,
+        is_eval: bool,
+        data_items: List,
+        verbose: bool,
+        num_gpus: int):
+        ...
 
-    def get_lr(self) -> Union[float, List[float]]:
-        """Return learning rate(s).
-
-        Returns:
-            Union[float, List[float]]: Model's initial learning rates.
-        """
-        pass
-
-    def get_scheduler(self, optimizer: torch.optim.Optimizer):
-        pass
-
-    def get_criterion(self):
-        pass
-
-    def format_batch(self):
-        pass
+    # def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]:
+    #     """Setup and return optimizer or optimizers."""
+    #     ...
+
+    # def get_lr(self) -> Union[float, List[float]]:
+    #     """Return learning rate(s).
+
+    #     Returns:
+    #         Union[float, List[float]]: Model's initial learning rates.
+    #     """
+    #     ...
+
+    # def get_scheduler(self, optimizer: torch.optim.Optimizer):
+    #     ...
+
+    # def get_criterion(self):
+    #     ...
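Editor's note: after this commit the class is a pure interface; everything a concrete model must supply is abstract. A minimal, hypothetical sketch of a conforming subclass (`MyTTSModel`, the batch keys, and the checkpoint layout are illustrative assumptions, not part of the repo):

import torch

from TTS.model import BaseTrainerModel

class MyTTSModel(BaseTrainerModel):
    @staticmethod
    def init_from_config(config, samples=None, verbose=False):
        return MyTTSModel()

    def forward(self, input, *args, aux_input={}, **kwargs):
        return {"model_outputs": input}

    def inference(self, input, aux_input={}):
        return {"model_outputs": input}

    def train_step(self, batch, criterion):
        outputs = self.forward(batch["input"])
        loss_dict = {"loss": criterion(outputs["model_outputs"], batch["target"])}
        return outputs, loss_dict

    def eval_step(self, batch, criterion):
        # the docstring above suggests this can often just delegate
        return self.train_step(batch, criterion)

    def load_checkpoint(self, config, checkpoint_path, eval=False, strict=True):
        state = torch.load(checkpoint_path, map_location="cpu")
        # assumes the checkpoint stores weights under a "model" key
        self.load_state_dict(state["model"], strict=strict)
        if eval:
            self.eval()

    def get_data_loader(self, config, assets, is_eval, data_items, verbose, num_gpus):
        ...  # return a torch.utils.data.DataLoader over data_items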