Yao Lirong's Blog

Decoupled Weight Decay Regularization (SGDW & AdamW)

2024/03/13

The paper Decoupled Weight Decay Regularization mainly introduces AdamW, which has been the SOTA optimizer ever since. It investigates why Adam with L2 regularization sometimes performs worse than SGD with L2 regularization. It demonstrates that weight decay and L2 regularization, two things people usually treat as equivalent, are not the same, and it shows that weight decay should be the go-to choice.

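Weight decay and L2 regularization are equivalent in SGD when the L2 coefficient is set to $\lambda' = \frac{\lambda}{\alpha}$, where $\lambda$ is the weight decay rate and $\alpha$ is the learning rate; this is why the two are commonly treated as the same thing. The situation is more complicated with adaptive gradient algorithms like Adam: the L2 gradient term $\lambda'\theta$ gets divided by Adam's per-parameter adaptive scaling, so weights with large historical gradients are regularized less than intended, and the equivalence breaks. Adam performs much better with decoupled weight decay than with L2 regularization, so the authors propose the new SOTA optimizer AdamW (Adam with decoupled weight decay). All the conclusions and main findings can be found in the first 2 pages of the paper, mostly in the Introduction section. I did not read the math.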
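For plain SGD the equivalence can be checked in a single step. This is a sketch in my own notation, not necessarily the paper's: $\theta_t$ are the weights, $f_t$ is the batch loss, $\alpha$ the learning rate, $\lambda$ the weight decay rate, and $\lambda'$ the L2 coefficient.

$$
\begin{aligned}
\text{weight decay:} \quad & \theta_{t+1} = (1 - \lambda)\,\theta_t - \alpha \nabla f_t(\theta_t) \\
\text{L2 regularization:} \quad & \theta_{t+1} = \theta_t - \alpha \left( \nabla f_t(\theta_t) + \lambda' \theta_t \right) = (1 - \alpha \lambda')\,\theta_t - \alpha \nabla f_t(\theta_t)
\end{aligned}
$$

The two updates coincide exactly when $\alpha \lambda' = \lambda$, i.e. $\lambda' = \frac{\lambda}{\alpha}$.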

This blog post from Fast.ai demonstrates in code how the two methods differ; it is a bit easier to understand than the paper, which doesn’t provide such a comparison.
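In the same spirit, here is a minimal dependency-free sketch (my own, not the Fast.ai code) of one Adam step on a single scalar weight, once with L2 regularization folded into the gradient and once with decoupled decay. The helper names, the hyperparameters, and the choice of measuring only the first step from $\theta = 1$ are all illustrative assumptions; the AdamW variant scales the decay by the learning rate, as PyTorch's `torch.optim.AdamW` does.

```python
import math

def adam_l2_step(theta, grad, m, v, t, lr=1e-3, reg=1e-2,
                 beta1=0.9, beta2=0.999, eps=1e-8):
    """Adam + L2: reg * theta is added to the gradient, so it flows through
    the moment estimates and gets divided by the adaptive denominator."""
    g = grad + reg * theta
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)  # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)  # bias-corrected second moment
    return theta - lr * m_hat / (math.sqrt(v_hat) + eps), m, v

def adamw_step(theta, grad, m, v, t, lr=1e-3, reg=1e-2,
               beta1=0.9, beta2=0.999, eps=1e-8):
    """AdamW: the moments only see the loss gradient; the decay term
    shrinks theta directly (scaled by lr, following PyTorch's convention)."""
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    return theta - lr * (m_hat / (math.sqrt(v_hat) + eps) + reg * theta), m, v

def decay_effect(step_fn, grad, reg=1e-2):
    """How far the regularizer alone moves theta = 1.0 on the first step."""
    no_reg, _, _ = step_fn(1.0, grad, 0.0, 0.0, 1, reg=0.0)
    with_reg, _, _ = step_fn(1.0, grad, 0.0, 0.0, 1, reg=reg)
    return no_reg - with_reg

for grad in (0.0, 10.0):
    print(f"grad={grad:5}: Adam+L2 decay={decay_effect(adam_l2_step, grad):.2e}, "
          f"AdamW decay={decay_effect(adamw_step, grad):.2e}")
```

With these toy numbers, Adam+L2 shrinks the zero-gradient weight by about 1e-3 but barely touches the weight with a large gradient, while AdamW shrinks both by exactly `lr * reg = 1e-5`, independent of the gradient history.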

Weight Decay in Transformers

AdamW is the go-to optimizer for LLMs these days. Researchers choose it because LLMs are hard to train and rarely overfit, so fast convergence matters more than extra regularization, and Adam is the best choice when convergence speed is considered (reference). People have also found that AdamW usually performs best with a large weight decay coefficient such as 0.05 or 0.1 (zhihu question, ViT paper: Training & Fine-tuning section).
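To get a feel for what a coefficient like 0.1 means, the sketch below assumes PyTorch's `torch.optim.AdamW` convention, where every step multiplies the weights by $(1 - \text{lr} \cdot \text{weight\_decay})$, and ignores gradients entirely. The learning rate of 6e-4 and the 10,000 steps are arbitrary illustrative values, not numbers taken from the references above.

```python
import math

def remaining_fraction(lr, weight_decay, steps):
    """Fraction of a weight left after `steps` decay-only updates, assuming
    the PyTorch AdamW convention theta <- theta * (1 - lr * weight_decay)."""
    return (1 - lr * weight_decay) ** steps

lr, steps = 6e-4, 10_000  # hypothetical training setup
for wd in (0.01, 0.05, 0.1):
    frac = remaining_fraction(lr, wd, steps)
    approx = math.exp(-lr * wd * steps)  # (1 - x)^n ~ exp(-x n) for small x
    print(f"weight_decay={wd}: {frac:.3f} left (exp approximation {approx:.3f})")
```

Under these assumptions, a weight that receives no gradient keeps roughly $e^{-0.6} \approx 0.55$ of its value at weight_decay=0.1; in real training the gradient pushes back against this shrinkage.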

When we apply weight decay in transformers, we apply it to all weights except LayerNorm parameters and biases.

In nanoGPT, Karpathy filtered them out using:

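```python
# inside GPT.configure_optimizers(self, weight_decay, learning_rate, ...)
# start with all of the candidate parameters that require grad
param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
# create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
# i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
optim_groups = [
    {'params': decay_params, 'weight_decay': weight_decay},
    {'params': nodecay_params, 'weight_decay': 0.0}
]
# the groups are then passed to the optimizer, e.g. torch.optim.AdamW(optim_groups, lr=learning_rate, ...)
```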

One caveat is that, in earlier versions, Karpathy did NOT weight decay embeddings:

```python
# LayerNorm (without the torch.nn prefix) is nanoGPT's own LayerNorm class
blacklist_weight_modules = (torch.nn.LayerNorm, LayerNorm, torch.nn.Embedding)
...
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
    # weights of blacklist modules will NOT be weight decayed
    no_decay.add(fpn)
```
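In practice the two rules differ only on the embedding matrices. The sketch below applies a simplified version of each rule to a hypothetical table of GPT-style parameters; the names, shapes, and the helpers `old_rule` / `new_rule` are made up for illustration, and module types are plain strings instead of `torch.nn` classes so it runs without PyTorch.

```python
# Hypothetical GPT-style parameters: (full name, owning module type, tensor shape).
PARAMS = [
    ("wte.weight", "Embedding", (50304, 768)),
    ("wpe.weight", "Embedding", (1024, 768)),
    ("h.0.ln_1.weight", "LayerNorm", (768,)),
    ("h.0.ln_1.bias", "LayerNorm", (768,)),
    ("h.0.attn.c_attn.weight", "Linear", (2304, 768)),
    ("h.0.attn.c_attn.bias", "Linear", (2304,)),
]

def old_rule(params):
    """Simplified earlier rule: decay only Linear weights; biases and
    LayerNorm/Embedding weights are not decayed."""
    return {n for n, mod, _ in params if n.endswith("weight") and mod == "Linear"}

def new_rule(params):
    """Latest rule: decay every parameter with 2 or more dimensions."""
    return {n for n, _, shape in params if len(shape) >= 2}

# Parameters decayed by the new rule but not by the old one
print(sorted(new_rule(PARAMS) - old_rule(PARAMS)))
```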

I couldn’t find any guidance on whether embeddings should be decayed when training a transformer, but Hugging Face’s Transformers implementation also decays embeddings, in line with Karpathy’s latest implementation:

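```python
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
# get_parameter_names(model, forbidden_layer_types) returns the names of all params outside layers of those types
decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]  # also drop all biases
```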
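Conceptually the Hugging Face filter works in two passes: drop parameters owned by LayerNorm-type layers, then drop anything whose name contains "bias". Below is a dependency-free sketch of that logic on a hypothetical BERT-style parameter list; `PARAMS`, `decay_parameter_names`, and the string `"LayerNorm"` standing in for `ALL_LAYERNORM_LAYERS` are my own simplifications, not the library's code.

```python
# Hypothetical BERT-style parameters: (parameter name, type of the layer that owns it).
PARAMS = [
    ("embeddings.word_embeddings.weight", "Embedding"),
    ("embeddings.LayerNorm.weight", "LayerNorm"),
    ("embeddings.LayerNorm.bias", "LayerNorm"),
    ("encoder.layer.0.attention.self.query.weight", "Linear"),
    ("encoder.layer.0.attention.self.query.bias", "Linear"),
]
FORBIDDEN_TYPES = {"LayerNorm"}  # stands in for ALL_LAYERNORM_LAYERS

def decay_parameter_names(params):
    # Pass 1: drop parameters owned by LayerNorm-type layers
    names = [n for n, layer_type in params if layer_type not in FORBIDDEN_TYPES]
    # Pass 2: drop every parameter whose name contains "bias"
    return [n for n in names if "bias" not in n]

print(decay_parameter_names(PARAMS))
```

The word embedding matrix survives both passes, so it ends up in the weight-decayed group, matching nanoGPT's latest `p.dim() >= 2` rule.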
