diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 08f81cc2..181ea1e0 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -155,7 +155,7 @@ def main(args): # pylint: disable=redefined-outer-name ap = AudioProcessor(**c.audio) model = setup_speaker_encoder_model(c) - optimizer = RAdam(model.parameters(), lr=c.lr) + optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=c.wd) # pylint: disable=redefined-outer-name meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=False)