mirror of https://github.com/coqui-ai/TTS.git
Add `_set_model_args` to BaseModel
This commit is contained in:
parent
9a0d8fa027
commit
16b70be0dd
16
TTS/model.py
16
TTS/model.py
|
@ -22,6 +22,14 @@ class BaseModel(nn.Module, ABC):
|
||||||
- 1D tensors `batch x 1`
|
- 1D tensors `batch x 1`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: Coqpit):
|
||||||
|
super().__init__()
|
||||||
|
self._set_model_args(config)
|
||||||
|
|
||||||
|
def _set_model_args(self, config: Coqpit):
|
||||||
|
"""Set model arguments from the config. Override this."""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:
|
def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:
|
||||||
"""Forward pass for the model mainly used in training.
|
"""Forward pass for the model mainly used in training.
|
||||||
|
@ -73,7 +81,7 @@ class BaseModel(nn.Module, ABC):
|
||||||
...
|
...
|
||||||
return outputs_dict, loss_dict
|
return outputs_dict, loss_dict
|
||||||
|
|
||||||
def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
|
def train_log(self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int) -> None:
|
||||||
"""Create visualizations and waveform examples for training.
|
"""Create visualizations and waveform examples for training.
|
||||||
|
|
||||||
For example, here you can plot spectrograms and generate sample sample waveforms from these spectrograms to
|
For example, here you can plot spectrograms and generate sample sample waveforms from these spectrograms to
|
||||||
|
@ -87,7 +95,7 @@ class BaseModel(nn.Module, ABC):
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[Dict, np.ndarray]: training plots and output waveform.
|
Tuple[Dict, np.ndarray]: training plots and output waveform.
|
||||||
"""
|
"""
|
||||||
return None, None
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
|
def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
|
||||||
|
@ -106,9 +114,9 @@ class BaseModel(nn.Module, ABC):
|
||||||
...
|
...
|
||||||
return outputs_dict, loss_dict
|
return outputs_dict, loss_dict
|
||||||
|
|
||||||
def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
|
def eval_log(self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int) -> None:
|
||||||
"""The same as `train_log()`"""
|
"""The same as `train_log()`"""
|
||||||
return None, None
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None:
|
def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None:
|
||||||
|
|
Loading…
Reference in New Issue