Use get_optimizer in Encoder

Edresson Casanova 2022-03-10 13:58:17 -03:00
parent a436fe40a3
commit a9208e9edd
3 changed files with 8 additions and 4 deletions


@@ -20,7 +20,7 @@ from TTS.tts.datasets import load_tts_samples
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
 from TTS.utils.io import load_fsspec, copy_model_files
-from TTS.utils.radam import RAdam
+from trainer.trainer_utils import get_optimizer
 from TTS.utils.training import check_update
 
 torch.backends.cudnn.enabled = True
@@ -244,7 +244,7 @@ def main(args):  # pylint: disable=redefined-outer-name
     ap = AudioProcessor(**c.audio)
     model = setup_speaker_encoder_model(c)
 
-    optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=c.wd)
+    optimizer = get_optimizer(c.optimizer, c.optimizer_params, c.lr, model)
 
     # pylint: disable=redefined-outer-name
     meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True)

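For context, the new call get_optimizer(c.optimizer, c.optimizer_params, c.lr, model) replaces the hardcoded RAdam(model.parameters(), lr=c.lr, weight_decay=c.wd), so the optimizer is now chosen from config fields. Below is a minimal sketch of what a config-driven factory like this typically does; the name-resolution logic is an illustrative assumption, not the trainer library's actual implementation.

# A minimal sketch in the spirit of trainer.trainer_utils.get_optimizer;
# the lookup below is an assumption for illustration, not the library's code.
import torch

def get_optimizer_sketch(optimizer_name, optimizer_params, lr, model):
    """Resolve a torch.optim class by name and build it from config values."""
    optimizer_params = optimizer_params or {}
    # Case-insensitive match against torch.optim classes,
    # e.g. "radam" -> torch.optim.RAdam, "adamw" -> torch.optim.AdamW.
    for attr in dir(torch.optim):
        cls = getattr(torch.optim, attr)
        if (
            isinstance(cls, type)
            and issubclass(cls, torch.optim.Optimizer)
            and attr.lower() == optimizer_name.lower()
        ):
            return cls(model.parameters(), lr=lr, **optimizer_params)
    raise ValueError(f"Optimizer '{optimizer_name}' not found in torch.optim")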

@@ -32,9 +32,13 @@ class BaseEncoderConfig(BaseTrainingConfig):
     loss: str = "angleproto"
     grad_clip: float = 3.0
     lr: float = 0.0001
+    optimizer: str = "radam"
+    optimizer_params: Dict = field(default_factory=lambda: {
+        "betas": [0.9, 0.999],
+        "weight_decay": 0
+    })
     lr_decay: bool = False
     warmup_steps: int = 4000
-    wd: float = 1e-6
 
     # logging params
     tb_model_param_stats: bool = False

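With these fields in place, the encoder optimizer can be swapped without touching training code. A hypothetical override follows; "AdamW" and the parameter values are illustrative, and the exact set of accepted names depends on how the trainer resolves them.

# Hypothetical usage; mirrors the positional call shape used in
# the training script above. Values here are illustrative, not defaults.
import torch
from trainer.trainer_utils import get_optimizer

model = torch.nn.Linear(4, 4)  # stand-in for the speaker encoder
optimizer = get_optimizer("AdamW", {"weight_decay": 0.01}, 1e-4, model)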

@@ -264,7 +264,7 @@ class BaseTTSConfig(BaseTrainingConfig):
     # dataset
     datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
     # optimizer
-    optimizer: str = None
+    optimizer: str = "radam"
     optimizer_params: dict = None
     # scheduler
     lr_scheduler: str = ""
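One behavioral note on this last hunk: BaseTTSConfig.optimizer previously defaulted to None, so model configs that never set it now fall back to "radam", while optimizer_params stays None unless overridden. A quick check, assuming the usual import path for this class:

# Import path assumed from the class name; adjust if the module differs.
from TTS.config.shared_configs import BaseTTSConfig

config = BaseTTSConfig()
print(config.optimizer)         # "radam" after this commit (was None)
print(config.optimizer_params)  # still None unless the model config sets it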