From d4deb2716f135297d4218dc6fa00533c0ca8b403 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Sat, 7 Aug 2021 21:47:48 +0000
Subject: [PATCH] Modify `get_optimizer` to accept a model argument

---
 TTS/utils/trainer_utils.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/TTS/utils/trainer_utils.py b/TTS/utils/trainer_utils.py
index 29915527..577f1a8d 100644
--- a/TTS/utils/trainer_utils.py
+++ b/TTS/utils/trainer_utils.py
@@ -1,5 +1,5 @@
 import importlib
-from typing import Dict
+from typing import Dict, List
 
 import torch
 
@@ -48,7 +48,7 @@ def get_scheduler(
 
 
 def get_optimizer(
-    optimizer_name: str, optimizer_params: dict, lr: float, model: torch.nn.Module
+    optimizer_name: str, optimizer_params: dict, lr: float, model: torch.nn.Module = None, parameters: List = None
 ) -> torch.optim.Optimizer:
     """Find, initialize and return a optimizer.
 
@@ -66,4 +66,6 @@ def get_optimizer(
         optimizer = getattr(module, "RAdam")
     else:
         optimizer = getattr(torch.optim, optimizer_name)
-    return optimizer(model.parameters(), lr=lr, **optimizer_params)
+    if model is not None:
+        parameters = model.parameters()
+    return optimizer(parameters, lr=lr, **optimizer_params)
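
A minimal usage sketch (not part of the patch; `my_model`, its `decoder`
submodule, and the `optimizer_params` values are hypothetical) showing how the
modified `get_optimizer` can be driven either by a model or by an explicit
parameter list:

    import torch
    from TTS.utils.trainer_utils import get_optimizer

    # Hypothetical model used only for illustration.
    my_model = torch.nn.Sequential(torch.nn.Linear(8, 8))

    # Option 1: pass a model; the patched function calls model.parameters() itself.
    optimizer = get_optimizer("Adam", {"weight_decay": 1e-6}, lr=1e-3, model=my_model)

    # Option 2: pass an explicit parameter list, e.g. to optimize only a subset.
    optimizer = get_optimizer(
        "Adam", {"weight_decay": 1e-6}, lr=1e-3, parameters=list(my_model.parameters())
    )

Note that when both arguments are given, the patched code overrides
`parameters` with `model.parameters()`, so `model` takes precedence.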