mirror of https://github.com/coqui-ai/TTS.git
Modify `get_optimizer` to accept a model argument
This commit is contained in:
parent 003e5579e8
commit d4deb2716f
@@ -1,5 +1,5 @@
 import importlib
-from typing import Dict
+from typing import Dict, List
 
 import torch
 
@@ -48,7 +48,7 @@ def get_scheduler(
 
 
 def get_optimizer(
-    optimizer_name: str, optimizer_params: dict, lr: float, model: torch.nn.Module
+    optimizer_name: str, optimizer_params: dict, lr: float, model: torch.nn.Module = None, parameters: List = None
 ) -> torch.optim.Optimizer:
     """Find, initialize and return a optimizer.
 
@@ -66,4 +66,6 @@ def get_optimizer(
         optimizer = getattr(module, "RAdam")
     else:
         optimizer = getattr(torch.optim, optimizer_name)
-    return optimizer(model.parameters(), lr=lr, **optimizer_params)
+    if model is not None:
+        parameters = model.parameters()
+    return optimizer(parameters, lr=lr, **optimizer_params)
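For context, a minimal sketch of the call paths this change enables, based only on the hunks above: the RAdam branch and the rest of the module are omitted, and the `Linear` model, optimizer name, and hyperparameter values are purely illustrative. Passing `model` keeps the old behaviour; passing `parameters` lets a caller hand over an explicit parameter list instead.

```python
from typing import List

import torch


def get_optimizer(
    optimizer_name: str,
    optimizer_params: dict,
    lr: float,
    model: torch.nn.Module = None,
    parameters: List = None,
) -> torch.optim.Optimizer:
    # Resolve the optimizer class by name from torch.optim
    # (the real function also special-cases an external RAdam import).
    optimizer = getattr(torch.optim, optimizer_name)
    # A model, if given, takes precedence and supplies its parameters;
    # otherwise the explicit `parameters` argument is used as-is.
    if model is not None:
        parameters = model.parameters()
    return optimizer(parameters, lr=lr, **optimizer_params)


# Old call style: hand over the whole model.
model = torch.nn.Linear(10, 2)
opt_a = get_optimizer("Adam", {"weight_decay": 1e-6}, lr=1e-3, model=model)

# New call style: pass an explicit parameter list, e.g. to train only a subset.
opt_b = get_optimizer(
    "Adam", {"weight_decay": 1e-6}, lr=1e-3, parameters=[model.weight]
)
```

Presumably the point of defaulting `model` to None is that trainers which optimize parameter subsets (for example, separate generator and discriminator updates) no longer need to wrap those subsets in a module just to build an optimizer.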