From 16b70be0dd1f0e15d3746f2fa7e51692e3a552f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 30 Sep 2021 14:27:04 +0000 Subject: [PATCH] Add `_set_model_args` to BaseModel --- TTS/model.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/TTS/model.py b/TTS/model.py index cfd1ec62..e34846bb 100644 --- a/TTS/model.py +++ b/TTS/model.py @@ -22,6 +22,14 @@ class BaseModel(nn.Module, ABC): - 1D tensors `batch x 1` """ + def __init__(self, config: Coqpit): + super().__init__() + self._set_model_args(config) + + def _set_model_args(self, config: Coqpit): + """Set model arguments from the config. Override this.""" + pass + @abstractmethod def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict: """Forward pass for the model mainly used in training. @@ -73,7 +81,7 @@ class BaseModel(nn.Module, ABC): ... return outputs_dict, loss_dict - def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: + def train_log(self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int) -> None: """Create visualizations and waveform examples for training. For example, here you can plot spectrograms and generate sample sample waveforms from these spectrograms to @@ -87,7 +95,7 @@ class BaseModel(nn.Module, ABC): Returns: Tuple[Dict, np.ndarray]: training plots and output waveform. """ - return None, None + pass @abstractmethod def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]: @@ -106,9 +114,9 @@ class BaseModel(nn.Module, ABC): ... return outputs_dict, loss_dict - def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: + def eval_log(self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int) -> None: """The same as `train_log()`""" - return None, None + pass @abstractmethod def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None: