mirror of https://github.com/coqui-ai/TTS.git
update set_weight_decay
This commit is contained in:
parent
8565c508e4
commit
99d7f2a666
|
@ -182,9 +182,9 @@ def weight_decay(optimizer):
|
||||||
return optimizer, current_lr
|
return optimizer, current_lr
|
||||||
|
|
||||||
|
|
||||||
def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v"}):
|
def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}):
|
||||||
"""
|
"""
|
||||||
Skip biases, BatchNorm parameters for weight decay
|
Skip biases, BatchNorm parameters, rnns.
|
||||||
and attention projection layer v
|
and attention projection layer v
|
||||||
"""
|
"""
|
||||||
decay = []
|
decay = []
|
||||||
|
@ -192,7 +192,8 @@ def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v"}):
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
if not param.requires_grad:
|
if not param.requires_grad:
|
||||||
continue
|
continue
|
||||||
if len(param.shape) == 1 or name in skip_list:
|
|
||||||
|
if len(param.shape) == 1 or any([skip_name in name for skip_name in skip_list]):
|
||||||
no_decay.append(param)
|
no_decay.append(param)
|
||||||
else:
|
else:
|
||||||
decay.append(param)
|
decay.append(param)
|
||||||
|
|
Loading…
Reference in New Issue