From a9208e9edd82f494aadd968c18fcdb6381e4dc0d Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 10 Mar 2022 13:58:17 -0300 Subject: [PATCH] Use get_optimizer in Encoder --- TTS/bin/train_encoder.py | 4 ++-- TTS/encoder/configs/base_encoder_config.py | 6 +++++- TTS/tts/configs/shared_configs.py | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 13f9368b..06aa41af 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -20,7 +20,7 @@ from TTS.tts.datasets import load_tts_samples from TTS.utils.audio import AudioProcessor from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict from TTS.utils.io import load_fsspec, copy_model_files -from TTS.utils.radam import RAdam +from trainer.trainer_utils import get_optimizer from TTS.utils.training import check_update torch.backends.cudnn.enabled = True @@ -244,7 +244,7 @@ def main(args): # pylint: disable=redefined-outer-name ap = AudioProcessor(**c.audio) model = setup_speaker_encoder_model(c) - optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=c.wd) + optimizer = get_optimizer(c.optimizer, c.optimizer_params, c.lr, model) # pylint: disable=redefined-outer-name meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True) diff --git a/TTS/encoder/configs/base_encoder_config.py b/TTS/encoder/configs/base_encoder_config.py index 50164eaf..02b88d66 100644 --- a/TTS/encoder/configs/base_encoder_config.py +++ b/TTS/encoder/configs/base_encoder_config.py @@ -32,9 +32,13 @@ class BaseEncoderConfig(BaseTrainingConfig): loss: str = "angleproto" grad_clip: float = 3.0 lr: float = 0.0001 + optimizer: str = "radam" + optimizer_params: Dict = field(default_factory=lambda: { + "betas": [0.9, 0.999], + "weight_decay": 0 + }) lr_decay: bool = False warmup_steps: int = 4000 - wd: float = 1e-6 # logging params tb_model_param_stats: bool = False diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py index a9b56ed4..dcc862e8 100644 --- a/TTS/tts/configs/shared_configs.py +++ b/TTS/tts/configs/shared_configs.py @@ -264,7 +264,7 @@ class BaseTTSConfig(BaseTrainingConfig): # dataset datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) # optimizer - optimizer: str = None + optimizer: str = "radam" optimizer_params: dict = None # scheduler lr_scheduler: str = ""