Use get_optimizer in Encoder

Edresson Casanova 2022-03-10 13:58:17 -03:00
parent a436fe40a3
commit a9208e9edd
3 changed files with 8 additions and 4 deletions


@@ -20,7 +20,7 @@ from TTS.tts.datasets import load_tts_samples
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
 from TTS.utils.io import load_fsspec, copy_model_files
-from TTS.utils.radam import RAdam
+from trainer.trainer_utils import get_optimizer
 from TTS.utils.training import check_update
 torch.backends.cudnn.enabled = True
@@ -244,7 +244,7 @@ def main(args):  # pylint: disable=redefined-outer-name
     ap = AudioProcessor(**c.audio)
     model = setup_speaker_encoder_model(c)
-    optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=c.wd)
+    optimizer = get_optimizer(c.optimizer, c.optimizer_params, c.lr, model)
     # pylint: disable=redefined-outer-name
     meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True)
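
For context beyond the diff: get_optimizer is a helper from the external Coqui Trainer package that looks up an optimizer class by name and instantiates it with the given parameter dict and learning rate. A minimal sketch of what the new call amounts to under the BaseEncoderConfig defaults added below; this is an illustration, not the Trainer source.

# Sketch (not Trainer source): what the new call resolves to with the
# defaults introduced in this commit. `model` is the encoder returned by
# setup_speaker_encoder_model(c), as in the hunk above.
from trainer.trainer_utils import get_optimizer

optimizer = get_optimizer(
    "radam",                                     # c.optimizer
    {"betas": [0.9, 0.999], "weight_decay": 0},  # c.optimizer_params
    0.0001,                                      # c.lr
    model,                                       # parameters are taken from the model
)
# Roughly equivalent to the removed hard-coded line:
# optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=c.wd)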


@@ -32,9 +32,13 @@ class BaseEncoderConfig(BaseTrainingConfig):
     loss: str = "angleproto"
     grad_clip: float = 3.0
     lr: float = 0.0001
+    optimizer: str = "radam"
+    optimizer_params: Dict = field(default_factory=lambda: {
+        "betas": [0.9, 0.999],
+        "weight_decay": 0
+    })
     lr_decay: bool = False
     warmup_steps: int = 4000
-    wd: float = 1e-6
     # logging params
     tb_model_param_stats: bool = False
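
With the optimizer now read from the config (and the standalone wd field folded into optimizer_params), switching the encoder to another optimizer no longer requires editing the training script. A hypothetical override, assuming get_optimizer resolves any torch.optim class name by attribute lookup; the field names match the new attributes above.

# Hypothetical user config: train the encoder with AdamW instead of RAdam.
config = BaseEncoderConfig(  # imported from TTS's encoder configs
    optimizer="AdamW",       # assumed to resolve to torch.optim.AdamW
    optimizer_params={"betas": [0.9, 0.999], "weight_decay": 0.01},
    lr=0.0001,
)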


@@ -264,7 +264,7 @@ class BaseTTSConfig(BaseTrainingConfig):
     # dataset
     datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
     # optimizer
-    optimizer: str = None
+    optimizer: str = "radam"
     optimizer_params: dict = None
     # scheduler
     lr_scheduler: str = ""
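
The same default propagates to the shared TTS config: model configs that never set optimizer now fall back to "radam" instead of None, while optimizer_params stays None and is still expected to come from the concrete model or user config. A small sketch of the resulting defaults; hypothetical usage, not part of the commit.

# BaseTTSConfig is a dataclass-style config, so the new default is visible
# on a bare instance (hypothetical check):
c = BaseTTSConfig()
assert c.optimizer == "radam"      # was None before this commit
assert c.optimizer_params is None  # concrete model configs supply this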