mirror of https://github.com/coqui-ai/TTS.git
Modify `get_optimizer` to accept a model argument
This commit is contained in:
parent
003e5579e8
commit
d4deb2716f
|
@ -1,5 +1,5 @@
|
||||||
import importlib
|
import importlib
|
||||||
from typing import Dict
|
from typing import Dict, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -48,7 +48,7 @@ def get_scheduler(
|
||||||
|
|
||||||
|
|
||||||
def get_optimizer(
|
def get_optimizer(
|
||||||
optimizer_name: str, optimizer_params: dict, lr: float, model: torch.nn.Module
|
optimizer_name: str, optimizer_params: dict, lr: float, model: torch.nn.Module = None, parameters: List = None
|
||||||
) -> torch.optim.Optimizer:
|
) -> torch.optim.Optimizer:
|
||||||
"""Find, initialize and return a optimizer.
|
"""Find, initialize and return a optimizer.
|
||||||
|
|
||||||
|
@ -66,4 +66,6 @@ def get_optimizer(
|
||||||
optimizer = getattr(module, "RAdam")
|
optimizer = getattr(module, "RAdam")
|
||||||
else:
|
else:
|
||||||
optimizer = getattr(torch.optim, optimizer_name)
|
optimizer = getattr(torch.optim, optimizer_name)
|
||||||
return optimizer(model.parameters(), lr=lr, **optimizer_params)
|
if model is not None:
|
||||||
|
parameters = model.parameters()
|
||||||
|
return optimizer(parameters, lr=lr, **optimizer_params)
|
||||||
|
|
Loading…
Reference in New Issue