mirror of https://github.com/coqui-ai/TTS.git
Add init_from_config
This commit is contained in:
parent
90cc45dd4e
commit
30cfafce56
|
@ -20,6 +20,7 @@ class BaseVocoder(BaseModel):
|
|||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self._set_model_args(config)
|
||||
|
||||
def _set_model_args(self, config: Coqpit):
|
||||
"""Setup model args based on the config type.
|
||||
|
|
|
@ -339,3 +339,7 @@ class Wavegrad(BaseVocoder):
|
|||
noise_schedule = self.config["train_noise_schedule"]
|
||||
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
|
||||
self.compute_noise_level(betas)
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "WavegradConfig"):
|
||||
return Wavegrad(config)
|
||||
|
|
|
@ -631,3 +631,7 @@ class Wavernn(BaseVocoder):
|
|||
def get_criterion(self):
|
||||
# define train functions
|
||||
return WaveRNNLoss(self.args.mode)
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "WavernnConfig"):
|
||||
return Wavernn(config)
|
||||
|
|
Loading…
Reference in New Issue