mirror of https://github.com/coqui-ai/TTS.git
Use get_optimizer in Encoder
This commit is contained in:
parent
a436fe40a3
commit
a9208e9edd
|
@ -20,7 +20,7 @@ from TTS.tts.datasets import load_tts_samples
|
|||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
|
||||
from TTS.utils.io import load_fsspec, copy_model_files
|
||||
from TTS.utils.radam import RAdam
|
||||
from trainer.trainer_utils import get_optimizer
|
||||
from TTS.utils.training import check_update
|
||||
|
||||
torch.backends.cudnn.enabled = True
|
||||
|
@ -244,7 +244,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
ap = AudioProcessor(**c.audio)
|
||||
model = setup_speaker_encoder_model(c)
|
||||
|
||||
optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=c.wd)
|
||||
optimizer = get_optimizer(c.optimizer, c.optimizer_params, c.lr, model)
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True)
|
||||
|
|
|
@ -32,9 +32,13 @@ class BaseEncoderConfig(BaseTrainingConfig):
|
|||
loss: str = "angleproto"
|
||||
grad_clip: float = 3.0
|
||||
lr: float = 0.0001
|
||||
optimizer: str = "radam"
|
||||
optimizer_params: Dict = field(default_factory=lambda: {
|
||||
"betas": [0.9, 0.999],
|
||||
"weight_decay": 0
|
||||
})
|
||||
lr_decay: bool = False
|
||||
warmup_steps: int = 4000
|
||||
wd: float = 1e-6
|
||||
|
||||
# logging params
|
||||
tb_model_param_stats: bool = False
|
||||
|
|
|
@ -264,7 +264,7 @@ class BaseTTSConfig(BaseTrainingConfig):
|
|||
# dataset
|
||||
datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
|
||||
# optimizer
|
||||
optimizer: str = None
|
||||
optimizer: str = "radam"
|
||||
optimizer_params: dict = None
|
||||
# scheduler
|
||||
lr_scheduler: str = ""
|
||||
|
|
Loading…
Reference in New Issue