Add init_from_config

This commit is contained in:
Eren Gölge 2021-12-07 08:56:57 +00:00
parent 90cc45dd4e
commit 30cfafce56
3 changed files with 9 additions and 0 deletions

View File

@ -20,6 +20,7 @@ class BaseVocoder(BaseModel):
def __init__(self, config):
super().__init__(config)
self._set_model_args(config)
def _set_model_args(self, config: Coqpit):
"""Setup model args based on the config type.

View File

@ -339,3 +339,7 @@ class Wavegrad(BaseVocoder):
noise_schedule = self.config["train_noise_schedule"]
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
self.compute_noise_level(betas)
@staticmethod
def init_from_config(config: "WavegradConfig"):
return Wavegrad(config)

View File

@ -631,3 +631,7 @@ class Wavernn(BaseVocoder):
def get_criterion(self):
# define train functions
return WaveRNNLoss(self.args.mode)
@staticmethod
def init_from_config(config: "WavernnConfig"):
return Wavernn(config)