From 4998ece8d852624cb67a5e5dfbb29cc5bb918104 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 8 Apr 2021 11:12:39 +0200 Subject: [PATCH] allow configuration of optimziers from the config file --- TTS/bin/train_vocoder_gan.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/TTS/bin/train_vocoder_gan.py b/TTS/bin/train_vocoder_gan.py index 99b8bba5..0af49c1f 100644 --- a/TTS/bin/train_vocoder_gan.py +++ b/TTS/bin/train_vocoder_gan.py @@ -476,10 +476,13 @@ def main(args): # pylint: disable=redefined-outer-name model_disc = setup_discriminator(c) # setup optimizers - optimizer_gen = RAdam(model_gen.parameters(), lr=c.lr_gen, weight_decay=0) - optimizer_disc = RAdam(model_disc.parameters(), - lr=c.lr_disc, - weight_decay=0) + # TODO: allow loading custom optimizers + optimizer_gen = None + optimizer_disc = None + optimizer_gen = getattr(torch.optim, c.optimizer) + optimizer_gen = optimizer_gen(lr=c.lr_gen, **c.optimizer_params) + optimizer_disc = getattr(torch.optim, c.optimizer) + optimizer_disc= optimizer_disc(lr=c.lr_gen, **c.optimizer_params) # schedulers scheduler_gen = None