update set_weight_decay

Eren Golge 2019-09-28 15:31:18 +02:00
parent 8565c508e4
commit 99d7f2a666
1 changed file with 4 additions and 3 deletions


@@ -182,9 +182,9 @@ def weight_decay(optimizer):
     return optimizer, current_lr


-def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v"}):
+def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}):
     """
-    Skip biases, BatchNorm parameters for weight decay
+    Skip biases, BatchNorm parameters, rnns.
     and attention projection layer v
     """
     decay = []
@@ -192,7 +192,8 @@ def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v"}):
     for name, param in model.named_parameters():
         if not param.requires_grad:
             continue
-        if len(param.shape) == 1 or name in skip_list:
+
+        if len(param.shape) == 1 or any([skip_name in name for skip_name in skip_list]):
             no_decay.append(param)
         else:
             decay.append(param)
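
The hunks above only show the selection loop: 1-d parameters (biases, BatchNorm scales) and, after this commit, any parameter whose qualified name contains one of the skip_list substrings are exempted from weight decay. A minimal self-contained sketch of the pattern follows; the TinyModel class and the returned parameter-group format are illustrative assumptions, not part of this diff.

    import torch
    from torch import nn

    def set_weight_decay(model, weight_decay,
                         skip_list=("decoder.attention.v", "rnn", "lstm", "gru", "embedding")):
        # Partition parameters: 1-d tensors (biases, norm scales) and parameters
        # whose name contains a skip_list substring are excluded from decay.
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if len(param.shape) == 1 or any(skip_name in name for skip_name in skip_list):
                no_decay.append(param)
            else:
                decay.append(param)
        # Assumed return format: standard PyTorch parameter groups
        # (the diff does not show the function's return statement).
        return [{"params": no_decay, "weight_decay": 0.0},
                {"params": decay, "weight_decay": weight_decay}]

    # Hypothetical toy model to exercise the substring matching.
    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.embedding = nn.Embedding(100, 16)         # "embedding" in name -> no decay
            self.lstm = nn.LSTM(16, 32, batch_first=True)  # "lstm" in name -> no decay
            self.linear = nn.Linear(32, 10)                # 2-d weight decays; 1-d bias does not

    groups = set_weight_decay(TinyModel(), weight_decay=1e-6)
    optimizer = torch.optim.Adam(groups, lr=1e-3)

Note the behavioral change in the second hunk: the old exact-match test (name in skip_list) only caught "decoder.attention.v", whereas the new substring test also catches every RNN, LSTM, GRU, and embedding parameter by name.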