Add `_set_model_args` to BaseModel

This commit is contained in:
Eren Gölge 2021-09-30 14:27:04 +00:00
parent 9a0d8fa027
commit 16b70be0dd
1 changed files with 12 additions and 4 deletions

View File

@ -22,6 +22,14 @@ class BaseModel(nn.Module, ABC):
- 1D tensors `batch x 1`
"""
def __init__(self, config: Coqpit):
super().__init__()
self._set_model_args(config)
def _set_model_args(self, config: Coqpit):
"""Set model arguments from the config. Override this."""
pass
@abstractmethod
def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:
"""Forward pass for the model mainly used in training.
@ -73,7 +81,7 @@ class BaseModel(nn.Module, ABC):
...
return outputs_dict, loss_dict
def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
def train_log(self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int) -> None:
"""Create visualizations and waveform examples for training.
For example, here you can plot spectrograms and generate sample sample waveforms from these spectrograms to
@ -87,7 +95,7 @@ class BaseModel(nn.Module, ABC):
Returns:
Tuple[Dict, np.ndarray]: training plots and output waveform.
"""
return None, None
pass
@abstractmethod
def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
@ -106,9 +114,9 @@ class BaseModel(nn.Module, ABC):
...
return outputs_dict, loss_dict
def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
def eval_log(self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int) -> None:
"""The same as `train_log()`"""
return None, None
pass
@abstractmethod
def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None: