From 30cfafce569b1d8d09b6efe10583ba68f6ae7b7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 7 Dec 2021 08:56:57 +0000 Subject: [PATCH] Add init_from_config --- TTS/vocoder/models/base_vocoder.py | 1 + TTS/vocoder/models/wavegrad.py | 4 ++++ TTS/vocoder/models/wavernn.py | 4 ++++ 3 files changed, 9 insertions(+) diff --git a/TTS/vocoder/models/base_vocoder.py b/TTS/vocoder/models/base_vocoder.py index 9d6ef26f..2728525c 100644 --- a/TTS/vocoder/models/base_vocoder.py +++ b/TTS/vocoder/models/base_vocoder.py @@ -20,6 +20,7 @@ class BaseVocoder(BaseModel): def __init__(self, config): super().__init__(config) + self._set_model_args(config) def _set_model_args(self, config: Coqpit): """Setup model args based on the config type. diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index 00142c91..9d6e431c 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -339,3 +339,7 @@ class Wavegrad(BaseVocoder): noise_schedule = self.config["train_noise_schedule"] betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"]) self.compute_noise_level(betas) + + @staticmethod + def init_from_config(config: "WavegradConfig"): + return Wavegrad(config) diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index b5b2343a..68f9b2c8 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -631,3 +631,7 @@ class Wavernn(BaseVocoder): def get_criterion(self): # define train functions return WaveRNNLoss(self.args.mode) + + @staticmethod + def init_from_config(config: "WavernnConfig"): + return Wavernn(config)