<p>Loss Scaling / Gradient Scaling was mentioned in <a href="#2024/03/01-Mixed-Precision-Training">Mixed-Precision Training</a> as one of the three techniques, but there are many points to be careful about in practice.</p>
<h2 id="overview-typical-use-case">Overview: Typical Use Case</h2><p>Here’s an overview of how to use amp.GradScaler
adaptedfrom <ahref=”https://pytorch.org/docs/stable/amp.html#gradient-scaling”>PyTorchofficial doc</a>.</p><h3 id="background">Background</h3><p>If the forward pass for a particular op has float16
inputs, under <ahref=”https://pytorch.org/docs/stable/amp.html”>Automatic MixedPrecision package - torch.amp</a>, the backward pass for that op willproduce gradients of the same data type - float16
.Gradient values with small magnitudes may not be representable infloat16
. These values will flush to zero (“underflow”), sothe update for the corresponding parameters will be lost.</p><h3 id="code">Code</h3><ol type="1"><li>scaler.scale(loss).backward()
: To prevent underflow,“gradient scaling’ multiplies the network’s loss(es) by a scale factorand invokes a backward pass on the scaled loss(es). In this way, thegradients on all parameters are scaled by this same factor and we don’thave to worry about them flush to zero. scaler.scale(loss)
multiplies a given loss by scaler
’s current scale factor.We then call backward on this scaled loss.</li><li>scaler.step(optimizer)
: After back-propagation, alllearnable parameters get their gradients, which are scaled to preventunderflow. Before applying whatever learning algorithm (Adam, SGD, …) onthem, we have to unscale them so the amount to be updated is correct.scaler.step(optimizer)
1. unscales gradients, 2. callsoptimizer.step()
, and does the previous two points safely:<ol type="1"><li>Internally invokes unscale_(optimizer)
(unlessunscale_()
was explicitly called for optimizer
earlier in the iteration). As part of the unscale_()
,gradients are checked for infs/NaNs to prevent overflow/underflow (Forwhy overflow can happen, check point 3 scaler.update
)</li><li>If no inf/NaN gradients are found, invokesoptimizer.step()
using the unscaled gradients. Otherwise,optimizer.step()
is skipped to avoid corrupting theparams.</li></ol></li><li>scaler.update()
: It would be great if we could justmultiply all gradients by a super big number so absolutely no underflowhappens, but doing so can cause overflow. The scaler estimates a goodscaling factor for each iteration, so neither underflow nor overflowhappens. scaler.update()
updates scaler
’sscale factor for next iteration.</li></ol>
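<p>To make the underflow point from the Background concrete, here is a minimal sketch (the particular magnitudes are just illustrative) of a value that flushes to zero in <code>float16</code> but survives once it is multiplied by a scale factor:</p>
<pre>import torch

# A gradient-sized value that float32 can represent but float16 cannot:
g = torch.tensor(1e-8)          # float32
print(g.half())                 # tensor(0., dtype=torch.float16) -- underflow, the value is lost

# Scaling first (analogous to scaling the loss before backward) keeps it representable.
# 65536.0 (2**16) is GradScaler's default init_scale.
scale = 65536.0
print((g * scale).half())       # ~tensor(0.0007, dtype=torch.float16) -- recoverable by dividing by scale
</pre>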
</pre></td><td class="code"><pre>scaler = torch.cuda.amp.GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
with autocast(device_type='cuda', dtype=torch.float16):
output = model(input)
loss = loss_fn(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
</pre>
<h2 id="gradient-clipping">Gradient Clipping</h2><p>Sometimes we want to clip gradients so that their norm (<a href="https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html#torch.nn.utils.clip_grad_norm_">torch.nn.utils.clip_grad_norm_()</a>) or maximum magnitude (<a href="https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_value_.html#torch.nn.utils.clip_grad_value_">torch.nn.utils.clip_grad_value_()</a>) is <= some user-imposed threshold.</p><p>The “gradients” here of course refer to the original, unscaled gradients. Therefore, you need to call <code>scaler.unscale_(optimizer)</code> before clipping.</p>
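<p>As a quick sanity check on what <code>unscale_</code> actually does, the sketch below (a hypothetical toy model; assumes a CUDA device, and every name here is illustrative) compares the gradient norm before and after unscaling. The scaled norm is simply <code>scaler.get_scale()</code> times the true norm, which is why clipping against a threshold only makes sense after <code>unscale_</code>:</p>
<pre>import torch

model = torch.nn.Linear(16, 1).cuda()                 # toy model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(8, 16, device='cuda')
with torch.autocast(device_type='cuda', dtype=torch.float16):
    loss = model(x).square().mean()
scaler.scale(loss).backward()                         # grads now hold scale * dL/dp

# clip_grad_norm_ with an infinite threshold clips nothing; it just reports the total norm.
scaled_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf'))

scaler.unscale_(optimizer)                            # divides the grads in place by the scale factor
true_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

print(scaled_norm / true_norm, scaler.get_scale())    # the two printed values match (roughly)

scaler.step(optimizer)   # grads already unscaled, so step() won't unscale them again
scaler.update()
</pre>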
</pre></td><td class="code"><pre>scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
with autocast(device_type='cuda', dtype=torch.float16):
output = model(input)
loss = loss_fn(output, target)
scaler.scale(loss).backward()
# Unscales the gradients of optimizer's assigned params in-place
scaler.unscale_(optimizer)
# Since the gradients of optimizer's assigned params are unscaled, clips as usual:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
# optimizer's gradients are already unscaled, so scaler.step does not unscale them,
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
scaler.step(optimizer)
scaler.update()
</pre>
<h2 id="gradient-accumulation">Gradient Accumulation</h2><p>Gradient accumulation adds up gradients over an effective batch of size <code>batch_per_step * gradient_accumulation_steps</code> (<code>* num_procs</code> if distributed). Operations related to scaled gradients should occur at effective-batch granularity. The following happens at the end of each effective batch:</p><ul><li>inf/NaN checking</li><li>step skipping if inf/NaN grads are found (see the short check after this list)</li><li>parameter update</li><li>scale update</li></ul><p>Within an effective batch, all the grads you accumulate should be scaled, and the scale factor should remain unchanged.</p>
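<p>The inf/NaN checking, step skipping, and scale update listed above can be observed directly. Here is a minimal sketch with a single hypothetical parameter (assumes a CUDA device; the hand-written inf gradient just stands in for a real overflow):</p>
<pre>import torch

param = torch.nn.Parameter(torch.ones(4, device='cuda'))
optimizer = torch.optim.SGD([param], lr=1.0)
scaler = torch.cuda.amp.GradScaler()          # default init_scale is 65536.0

# One ordinary scaled backward pass, just to initialize the scaler's internal scale.
scaler.scale(param.sum()).backward()

# Fake an overflow: pretend the scaled backward pass produced inf gradients.
param.grad.fill_(float('inf'))
before = param.detach().clone()

scaler.step(optimizer)    # finds the infs while unscaling and skips optimizer.step()
scaler.update()           # multiplies the scale by backoff_factor (default 0.5)

print(torch.equal(param.detach(), before))    # True: the parameter was not corrupted
print(scaler.get_scale())                     # 32768.0 == 65536.0 * 0.5
</pre>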
</pre></td><td class="code"><pre>scaler = GradScaler()
for epoch in epochs:
for micro_step in range(gradient_accumulation_steps):
input, target = get_data(epoch, micro_step)
with autocast(device_type='cuda', dtype=torch.float16):
output = model(input)
loss = loss_fn(output, target)
loss = loss / gradient_accumulation_steps
# Accumulates scaled gradients.
scaler.scale(loss).backward()
    # If you need to work with unscaled gradients (e.g., to clip them),
    # call scaler.unscale_(optimizer) here, after all the (scaled) grads
    # for the upcoming step have been accumulated.
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
</pre>
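<p>Putting the last two sections together: if you also want gradient clipping while accumulating, call <code>unscale_</code> exactly once per effective batch, after the last micro-step’s <code>backward()</code> and before <code>step()</code>. A sketch under the same assumptions as the loops above (<code>model</code>, <code>optimizer</code>, <code>loss_fn</code>, <code>get_data</code>, <code>max_norm</code>, and the loop bounds are defined elsewhere):</p>
<pre>scaler = GradScaler()
for epoch in epochs:
    for micro_step in range(gradient_accumulation_steps):
        input, target = get_data(epoch, micro_step)
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target) / gradient_accumulation_steps
        # Accumulate scaled gradients; the scale factor stays fixed within the effective batch.
        scaler.scale(loss).backward()
    # All scaled grads for this effective batch have been accumulated: unscale once, then clip.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    scaler.step(optimizer)    # grads are already unscaled; step() still skips on inf/NaN
    scaler.update()
    optimizer.zero_grad()
</pre>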