Add `_set_model_args` to BaseModel

This commit is contained in:
Eren Gölge 2021-09-30 14:27:04 +00:00
parent 9a0d8fa027
commit 16b70be0dd
1 changed files with 12 additions and 4 deletions

View File

@ -22,6 +22,14 @@ class BaseModel(nn.Module, ABC):
- 1D tensors `batch x 1` - 1D tensors `batch x 1`
""" """
def __init__(self, config: Coqpit):
super().__init__()
self._set_model_args(config)
def _set_model_args(self, config: Coqpit):
"""Set model arguments from the config. Override this."""
pass
@abstractmethod @abstractmethod
def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict: def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:
"""Forward pass for the model mainly used in training. """Forward pass for the model mainly used in training.
@ -73,7 +81,7 @@ class BaseModel(nn.Module, ABC):
... ...
return outputs_dict, loss_dict return outputs_dict, loss_dict
def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: def train_log(self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int) -> None:
"""Create visualizations and waveform examples for training. """Create visualizations and waveform examples for training.
For example, here you can plot spectrograms and generate sample sample waveforms from these spectrograms to For example, here you can plot spectrograms and generate sample sample waveforms from these spectrograms to
@ -87,7 +95,7 @@ class BaseModel(nn.Module, ABC):
Returns: Returns:
Tuple[Dict, np.ndarray]: training plots and output waveform. Tuple[Dict, np.ndarray]: training plots and output waveform.
""" """
return None, None pass
@abstractmethod @abstractmethod
def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]: def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
@ -106,9 +114,9 @@ class BaseModel(nn.Module, ABC):
... ...
return outputs_dict, loss_dict return outputs_dict, loss_dict
def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: def eval_log(self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int) -> None:
"""The same as `train_log()`""" """The same as `train_log()`"""
return None, None pass
@abstractmethod @abstractmethod
def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None: def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None: