Loss Scaling / Gradient Scaling was mentioned in Mixed-Precision Training as one of the three techniques, but there are many points to be careful about in practice.
Overview: Typical Use Case
Here’s an overview of how to use `amp.GradScaler`, adapted from the PyTorch official docs.
Background
Under the Automatic Mixed Precision package (`torch.amp`), if the forward pass of a particular op has `float16` inputs, the backward pass of that op will produce gradients of the same data type, `float16`.
Gradient values with small magnitudes may not be representable in `float16`. These values will flush to zero (“underflow”), so the update for the corresponding parameters will be lost.
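A quick demonstration of the problem (a minimal sketch; the exact printed values depend on rounding, but the flush to zero does not):

```python
import torch

# The smallest positive (subnormal) float16 value is ~5.96e-8;
# anything smaller flushes to zero when cast to float16.
g = torch.tensor(1e-8)    # fine in float32
print(g.half())           # tensor(0., dtype=torch.float16) -- underflow, update lost
print((g * 1024).half())  # ~1e-5, representable -- rescued by scaling first
```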
Code
1. `scaler.scale(loss).backward()`: To prevent underflow, “gradient scaling” multiplies the network’s loss(es) by a scale factor and invokes a backward pass on the scaled loss(es). In this way, the gradients of all parameters are scaled by the same factor, and we don’t have to worry about them flushing to zero. `scaler.scale(loss)` multiplies a given loss by `scaler`’s current scale factor; we then call backward on this scaled loss.
2. `scaler.step(optimizer)`: After back-propagation, all learnable parameters get their gradients, which are scaled to prevent underflow. Before applying whatever learning algorithm (Adam, SGD, …) to them, we have to unscale them so the amount to be updated is correct. `scaler.step(optimizer)` first unscales the gradients, then calls `optimizer.step()`, and it does both safely:
   - It internally invokes `unscale_(optimizer)` (unless `unscale_()` was explicitly called for `optimizer` earlier in the iteration). As part of `unscale_()`, gradients are checked for infs/NaNs, i.e., overflow (for why overflow can happen, check point 3, `scaler.update()`).
   - If no inf/NaN gradients are found, it invokes `optimizer.step()` using the unscaled gradients. Otherwise, `optimizer.step()` is skipped to avoid corrupting the params.
3. `scaler.update()`: It would be great if we could just multiply all gradients by a super big number so that absolutely no underflow happens, but doing so can cause overflow. The scaler instead estimates a good scale factor for each iteration, so that neither underflow nor overflow happens. `scaler.update()` updates `scaler`’s scale factor for the next iteration.
```python
scaler = torch.cuda.amp.GradScaler()
```
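Putting the three calls together, the typical loop (in the style of the PyTorch docs example this section adapts) looks like the sketch below; `model`, `data`, `loss_fn`, `optimizer`, and `num_epochs` are placeholders assumed to be defined elsewhere:

```python
import torch

scaler = torch.cuda.amp.GradScaler()

for epoch in range(num_epochs):
    for input, target in data:
        optimizer.zero_grad()

        # Run the forward pass under autocast so eligible ops use float16.
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target)

        # 1. Scale the loss, then backward: all grads are scaled too.
        scaler.scale(loss).backward()

        # 2. Unscale grads, check for infs/NaNs, then step (or skip the step).
        scaler.step(optimizer)

        # 3. Adjust the scale factor for the next iteration.
        scaler.update()
```

For the record, the scale factor starts at `init_scale=65536.0`, is multiplied by `backoff_factor=0.5` whenever a step is skipped due to inf/NaN grads, and by `growth_factor=2.0` after `growth_interval=2000` consecutive successful steps; all four are `GradScaler` constructor arguments, so you can tune them.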
Working with Unscaled Gradients - Gradient clipping
Gradient clipping manipulates a set of gradients such that their global norm (`torch.nn.utils.clip_grad_norm_()`) or maximum magnitude (`torch.nn.utils.clip_grad_value_()`) is <= some user-imposed threshold. The “gradients” here of course refer to the original, unscaled gradients. Therefore, you need to call `scaler.unscale_(optimizer)` before clipping.
```python
scaler = GradScaler()
```
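Concretely, the clipping pattern from the PyTorch docs looks like the sketch below; `max_norm` and the other training objects are placeholders:

```python
import torch
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()

for epoch in range(num_epochs):
    for input, target in data:
        optimizer.zero_grad()
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()

        # Grads are scaled at this point; unscale them in place first.
        scaler.unscale_(optimizer)

        # Now the grads are the "real" ones, so clip as usual.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

        # step() knows unscale_ was already called and won't unscale twice.
        scaler.step(optimizer)
        scaler.update()
```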
Working with Scaled Gradients - Gradient accumulation
Gradient accumulation adds gradients over an effective batch of size `batch_per_step * gradient_accumulation_steps` (`* num_procs` if distributed). Operations related to scaled gradients should occur at effective-batch granularity. The following happens at the end of each effective batch:
- inf/NaN checking
- step skipping if inf/NaN grads are found
- parameter update
- scale update
Within an effective batch, all the grads you accumulate should be scaled, and the scale factor should remain unchanged, so don’t call `scaler.update()` (or `unscale_`/`step`) until the effective batch is complete.
```python
scaler = GradScaler()
```
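A sketch of this pattern, following the PyTorch docs’ gradient-accumulation example; `iters_to_accumulate` and the other names are placeholders:

```python
import torch
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()

for epoch in range(num_epochs):
    for i, (input, target) in enumerate(data):
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)
            # Average the loss so the accumulated grads match one big batch.
            loss = loss / iters_to_accumulate

        # Accumulate *scaled* grads; the scale factor stays unchanged here.
        scaler.scale(loss).backward()

        if (i + 1) % iters_to_accumulate == 0:
            # End of the effective batch: inf/NaN check, (maybe) step,
            # then update the scale factor -- all at this granularity.
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
```

Note that `optimizer.zero_grad()` also moves to effective-batch granularity; clearing grads mid-accumulation would throw away part of the effective batch.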
These examples may seem too vanilla; check out nanoGPT’s mixed-precision training loop for a lively combination of gradient accumulation and gradient clipping.
Working with Scaled Gradients - Gradient penalty
What is a gradient penalty, and why does it need special care under gradient scaling? This thread explains:
https://discuss.pytorch.org/t/whats-the-use-of-scaled-grad-params-in-this-example-of-gradient-penalty-with-scaled-gradients/199741/3
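In short, a gradient penalty adds a term built from intermediate gradients (e.g., their L2 norm) to the loss. Those intermediate gradients must be unscaled before entering the penalty, but computing them from an unscaled `float16` loss risks the very underflow we are trying to avoid. The PyTorch docs therefore scale first, take `torch.autograd.grad`, and unscale manually. Here is a sketch of that pattern; `model`, `data`, and the other training objects are placeholders:

```python
import torch
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()

for epoch in range(num_epochs):
    for input, target in data:
        optimizer.zero_grad()
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        # Scale the loss before autograd.grad so the intermediate grads
        # don't underflow in float16.
        scaled_grad_params = torch.autograd.grad(
            outputs=scaler.scale(loss),
            inputs=model.parameters(),
            create_graph=True,  # keep the graph: the penalty gets differentiated too
        )

        # Manually unscale: the penalty must be computed from real grads.
        inv_scale = 1.0 / scaler.get_scale()
        grad_params = [p * inv_scale for p in scaled_grad_params]

        # Compute the penalty under autocast and add it to the loss.
        with autocast():
            grad_norm = torch.stack([g.pow(2).sum() for g in grad_params]).sum().sqrt()
            loss = loss + grad_norm

        # From here on it's the usual scale -> backward -> step -> update.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
```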
Epilogue
This wiki page from Deepgram provides a detailed view of what gradient scaling is about, but somehow it just reads like AI-generated content, maybe because it gives too many unnecessary details.