Modify `get_optimizer` to accept a model argument

Eren Gölge 2021-08-07 21:47:48 +00:00
parent 003e5579e8
commit d4deb2716f
1 changed file with 5 additions and 3 deletions


```diff
@@ -1,5 +1,5 @@
 import importlib
-from typing import Dict
+from typing import Dict, List

 import torch
@@ -48,7 +48,7 @@ def get_scheduler(
 def get_optimizer(
-    optimizer_name: str, optimizer_params: dict, lr: float, model: torch.nn.Module
+    optimizer_name: str, optimizer_params: dict, lr: float, model: torch.nn.Module = None, parameters: List = None
 ) -> torch.optim.Optimizer:
     """Find, initialize and return a optimizer.
@@ -66,4 +66,6 @@ def get_optimizer(
         optimizer = getattr(module, "RAdam")
     else:
         optimizer = getattr(torch.optim, optimizer_name)
-    return optimizer(model.parameters(), lr=lr, **optimizer_params)
+    if model is not None:
+        parameters = model.parameters()
+    return optimizer(parameters, lr=lr, **optimizer_params)
```
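
The change makes `model` optional and lets callers hand over an explicit `parameters` iterable instead, e.g. to optimize only a subset of a model's weights. Below is a minimal sketch of the two call styles; the diff does not show the file's module path, so the post-commit function is reproduced inline here (with the RAdam branch omitted), and the model, optimizer name, and hyperparameters are illustrative only.

```python
from typing import List

import torch


def get_optimizer(
    optimizer_name: str, optimizer_params: dict, lr: float, model: torch.nn.Module = None, parameters: List = None
) -> torch.optim.Optimizer:
    """Post-commit behavior: resolve the optimizer class by name and
    initialize it from either a model or an explicit parameter list."""
    optimizer = getattr(torch.optim, optimizer_name)
    if model is not None:
        # A model was given, so its full parameter set is used.
        parameters = model.parameters()
    return optimizer(parameters, lr=lr, **optimizer_params)


model = torch.nn.Linear(10, 2)

# Old-style call: pass the whole model; its .parameters() are used.
opt_a = get_optimizer("Adam", {"betas": (0.9, 0.999)}, lr=1e-3, model=model)

# New-style call: pass an explicit parameter list instead, here
# restricting optimization to the weight matrix only.
opt_b = get_optimizer("Adam", {"betas": (0.9, 0.999)}, lr=1e-3, parameters=[model.weight])
```

Note that if both arguments are supplied, `model` wins: the `if` branch overwrites whatever was passed in `parameters` before the optimizer is constructed.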