diff --git a/TTS/model.py b/TTS/model.py index cfd1ec62..e34846bb 100644 --- a/TTS/model.py +++ b/TTS/model.py @@ -22,6 +22,14 @@ class BaseModel(nn.Module, ABC): - 1D tensors `batch x 1` """ + def __init__(self, config: Coqpit): + super().__init__() + self._set_model_args(config) + + def _set_model_args(self, config: Coqpit): + """Set model arguments from the config. Override this.""" + pass + @abstractmethod def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict: """Forward pass for the model mainly used in training. @@ -73,7 +81,7 @@ class BaseModel(nn.Module, ABC): ... return outputs_dict, loss_dict - def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: + def train_log(self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int) -> None: """Create visualizations and waveform examples for training. For example, here you can plot spectrograms and generate sample sample waveforms from these spectrograms to @@ -87,7 +95,7 @@ class BaseModel(nn.Module, ABC): Returns: Tuple[Dict, np.ndarray]: training plots and output waveform. """ - return None, None + pass @abstractmethod def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]: @@ -106,9 +114,9 @@ class BaseModel(nn.Module, ABC): ... return outputs_dict, loss_dict - def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: + def eval_log(self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int) -> None: """The same as `train_log()`""" - return None, None + pass @abstractmethod def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None: