diff --git a/train.py b/train.py index 7ca06ed8..791c5d4f 100644 --- a/train.py +++ b/train.py @@ -361,7 +361,8 @@ def main(args): c.r) optimizer = optim.Adam(model.parameters(), lr=c.lr) - optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr) + optimizer_st = optim.SGD(model.decoder.stopnet.parameters(), lr=0.01, + momentum=0.9, nesterov=True) criterion = L1LossMasked() criterion_st = nn.BCELoss()