Modify `get_optimizer` to accept a model argument

This commit is contained in:
Eren Gölge 2021-08-07 21:47:48 +00:00
parent 003e5579e8
commit d4deb2716f
1 changed files with 5 additions and 3 deletions

View File

@ -1,5 +1,5 @@
import importlib
from typing import Dict
from typing import Dict, List
import torch
@ -48,7 +48,7 @@ def get_scheduler(
def get_optimizer(
optimizer_name: str, optimizer_params: dict, lr: float, model: torch.nn.Module
optimizer_name: str, optimizer_params: dict, lr: float, model: torch.nn.Module = None, parameters: List = None
) -> torch.optim.Optimizer:
"""Find, initialize and return a optimizer.
@ -66,4 +66,6 @@ def get_optimizer(
optimizer = getattr(module, "RAdam")
else:
optimizer = getattr(torch.optim, optimizer_name)
return optimizer(model.parameters(), lr=lr, **optimizer_params)
if model is not None:
parameters = model.parameters()
return optimizer(parameters, lr=lr, **optimizer_params)