From 16db5159f12568a2f9fc9ab150a7f00f2c08897b Mon Sep 17 00:00:00 2001
From: Eren
Date: Wed, 19 Sep 2018 14:25:30 +0200
Subject: [PATCH] Weight decay described here: http://www.fast.ai/2018/07/02/adam-weight-decay/

---
 train.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/train.py b/train.py
index 14def5e7..55811097 100644
--- a/train.py
+++ b/train.py
@@ -89,6 +89,9 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
 
         # backpass and check the grad norm for spec losses
         loss.backward(retain_graph=True)
+        for group in optimizer.param_groups:
+            for param in group['params']:
+                param.data = param.data.add(-c.wd * group['lr'], param.data)
         grad_norm, skip_flag = check_update(model, 1)
         if skip_flag:
             optimizer.zero_grad()
@@ -98,6 +101,9 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
 
         # backpass and check the grad norm for stop loss
         stop_loss.backward()
+        for group in optimizer_st.param_groups:
+            for param in group['params']:
+                param.data = param.data.add(-c.wd * group['lr'], param.data)
         grad_norm_st, skip_flag = check_update(model.decoder.stopnet, 0.5)
         if skip_flag:
             optimizer_st.zero_grad()
@@ -390,9 +396,9 @@ def main(args):
     model = Tacotron(c.embedding_size, ap.num_freq, c.num_mels, c.r)
     print(" | > Num output units : {}".format(ap.num_freq), flush=True)
 
-    optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=c.wd)
+    optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0)
     optimizer_st = optim.Adam(
-        model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=c.wd)
+        model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
 
     criterion = L1LossMasked()
     criterion_st = nn.BCELoss()
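
Note: the added lines implement the decoupled weight decay described in the linked fast.ai post. The decay term is removed from Adam (weight_decay=0) and the weights are shrunk directly by lr * wd once per step, outside Adam's adaptive gradient scaling. Below is a minimal sketch of that per-step update in isolation, written with the current PyTorch API (the two-argument Tensor.add(scalar, tensor) form used in the patch is deprecated). The model, lr, and wd values are placeholders standing in for the Tacotron model and the c.lr / c.wd config entries in train.py.

    import torch
    from torch import nn, optim

    model = nn.Linear(10, 10)   # stand-in for the Tacotron model
    lr, wd = 1e-3, 1e-6         # placeholders for c.lr and c.wd
    # weight_decay=0: the decay is applied manually, not inside Adam
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0)

    loss = model(torch.randn(4, 10)).pow(2).mean()
    loss.backward()

    # Decoupled weight decay: w <- w * (1 - lr * wd), applied directly to the
    # weights so it is not rescaled by Adam's per-parameter adaptive step sizes.
    with torch.no_grad():
        for group in optimizer.param_groups:
            for param in group["params"]:
                param.mul_(1 - group["lr"] * wd)

    optimizer.step()
    optimizer.zero_grad()

Because the decay only rescales the weights and never touches the gradients, Adam's moment estimates are still computed from the raw loss gradient, which is the point of decoupling it from the optimizer.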