Loss Scaling / Gradient Scaling was mentioned in <a href="#2024/03/01-Mixed-Precision-Training">Mixed-Precision Training</a> as one of the 3 techniques, but there are many points to be careful with in practice.

<h2 id="overview-typical-use-case">Overview: Typical Use Case</h2><p>Here’s an overview of how to use amp.GradScaler, adapted from the <a href="https://pytorch.org/docs/stable/amp.html#gradient-scaling">PyTorch official doc</a>.</p><h3 id="background">Background</h3><p>If the forward pass for a particular op has float16 inputs, under <a href="https://pytorch.org/docs/stable/amp.html">Automatic Mixed Precision package - torch.amp</a>, the backward pass for that op will produce gradients of the same data type - float16. Gradient values with small magnitudes may not be representable in float16. These values will flush to zero (“underflow”), so the update for the corresponding parameters will be lost.</p><h3 id="code">Code</h3><ol type="1"><li>scaler.scale(loss).backward(): To prevent underflow, “gradient scaling” multiplies the network’s loss(es) by a scale factor and invokes a backward pass on the scaled loss(es). In this way, the gradients of all parameters are scaled by the same factor, so we don’t have to worry about them flushing to zero. scaler.scale(loss) multiplies a given loss by scaler’s current scale factor. We then call backward on this scaled loss.</li><li>scaler.step(optimizer): After back-propagation, all learnable parameters get their gradients, which are scaled to prevent underflow. Before applying whatever learning algorithm (Adam, SGD, …) to them, we have to unscale them so the amount to be updated is correct. scaler.step(optimizer) 1. unscales gradients, 2. calls optimizer.step(), and does the previous two points safely:<ol type="1"><li>Internally invokes unscale_(optimizer) (unless unscale_() was explicitly called for optimizer earlier in the iteration). As part of the unscale_(), gradients are checked for infs/NaNs (for why overflow can happen, check point 3, scaler.update).</li><li>If no inf/NaN gradients are found, invokes optimizer.step() using the unscaled gradients. Otherwise, optimizer.step() is skipped to avoid corrupting the params.</li></ol></li><li>scaler.update(): It would be great if we could just multiply all gradients by a super big number so absolutely no underflow happens, but doing so can cause overflow. The scaler estimates a good scaling factor for each iteration, so neither underflow nor overflow happens. scaler.update() updates scaler’s scale factor for the next iteration.</li></ol>
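<p>To see the underflow that motivates all this, here is a tiny stdlib-only sketch. It round-trips values through IEEE 754 half precision using Python's struct module; the 2**16 factor is just an illustrative scale, not something prescribed by PyTorch:</p>

```python
import struct

def to_float16(x: float) -> float:
    """Round-trip a Python float through IEEE 754 half precision."""
    return struct.unpack('e', struct.pack('e', x))[0]

# A small gradient value underflows to zero in float16...
tiny_grad = 1e-8
print(to_float16(tiny_grad))             # 0.0 -- the update is lost

# ...but after scaling the loss (and hence the gradients) by 2**16,
# the value survives in float16 and can be unscaled later in float32.
scale = 2.0 ** 16
scaled = to_float16(tiny_grad * scale)
print(scaled)                            # ~6.55e-04, representable in float16
print(scaled / scale)                    # ~1e-8, recovered after unscaling
```

The same effect is why the scaled gradients must be unscaled before optimizer.step(): otherwise the parameter update would be 2**16 times too large.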

<table><tr><td class="gutter"><pre>1
2
3
4
5
6
7
8
9
10
11
</pre></td><td class="code"><pre>scaler = torch.cuda.amp.GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
</pre></td></tr></table>
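<p>The dynamic part lives in scaler.update(): roughly, the scale is multiplied by a backoff factor whenever inf/NaN gradients were found, and grown again after a long run of clean iterations (the documented defaults are init_scale=65536, backoff_factor=0.5, growth_factor=2.0, growth_interval=2000). A toy re-implementation of just this bookkeeping, as a sketch of the logic rather than the real implementation:</p>

```python
class ToyGradScaler:
    """Sketch of GradScaler's dynamic scale update (not the real implementation)."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._growth_tracker = 0

    def update(self, found_inf: bool) -> None:
        if found_inf:
            # Overflow this iteration: the optimizer step was skipped,
            # back off the scale and restart the growth counter.
            self.scale *= self.backoff_factor
            self._growth_tracker = 0
        else:
            self._growth_tracker += 1
            if self._growth_tracker == self.growth_interval:
                # A long run of finite gradients: try a larger scale.
                self.scale *= self.growth_factor
                self._growth_tracker = 0

scaler = ToyGradScaler(growth_interval=3)
for found_inf in [False, False, True, False, False, False]:
    scaler.update(found_inf)
print(scaler.scale)   # 65536.0: halved once on overflow, doubled after 3 clean steps
```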
<h2 id="working-with-unscaled-gradients---gradient-clipping"><a href="https://pytorch.org/docs/stable/notes/amp_examples.html#id4">Working with Unscaled Gradients - Gradient clipping</a></h2><p>Gradient clipping manipulates a set of gradients such that their global norm (<a href="https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html#torch.nn.utils.clip_grad_norm_">torch.nn.utils.clip_grad_norm_()</a>) or maximum magnitude (<a href="https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_value_.html#torch.nn.utils.clip_grad_value_">torch.nn.utils.clip_grad_value_()</a>) is &lt;= some user-imposed threshold.</p><p>The “gradients” here of course refer to the original, unscaled gradients. Therefore, you need to call scaler.unscale_(optimizer) before clipping.</p>
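<p>To make the point concrete, here is a plain-Python analogue of global-norm clipping (clip_by_global_norm is a hypothetical helper, not the real torch.nn.utils.clip_grad_norm_). Clipping before unscaling would compare the scaled norm against max_norm, so after unscaling the true gradients would end up near max_norm / scale instead of max_norm:</p>

```python
import math

def global_norm(grads):
    """L2 norm over the whole gradient list, as if it were one flat vector."""
    return math.sqrt(sum(g * g for g in grads))

def clip_by_global_norm(grads, max_norm):
    """Plain-Python analogue of clipping by global norm (hypothetical helper)."""
    norm = global_norm(grads)
    if norm > max_norm:
        coef = max_norm / norm
        return [g * coef for g in grads]
    return grads

grads = [0.3, -0.4, 1.2]                  # pretend these are the true (unscaled) grads
clipped = clip_by_global_norm(grads, max_norm=1.0)
print(global_norm(clipped))               # ~1.0: norm brought down to max_norm

# Forgetting to unscale first: the scaled norm dwarfs max_norm, so after
# clipping and then unscaling, the gradients land near max_norm / scale.
scale = 2.0 ** 16
scaled = [g * scale for g in grads]
wrong = [g / scale for g in clip_by_global_norm(scaled, max_norm=1.0)]
print(global_norm(wrong))                 # ~1.5e-05, i.e. max_norm / scale
```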
<table><tr><td class="gutter"><pre>1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
</pre></td><td class="code"><pre>scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()

        # Unscales the gradients of optimizer's assigned params in-place
        scaler.unscale_(optimizer)

        # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

        # optimizer's gradients are already unscaled, so scaler.step does not unscale them,
        # although it still skips optimizer.step() if the gradients contain infs or NaNs.
        scaler.step(optimizer)

        scaler.update()
</pre></td></tr></table>
<h2 id="working-with-scaled-gradients---gradient-accumulation"><a href="https://pytorch.org/docs/stable/notes/amp_examples.html#id6">Working with Scaled Gradients - Gradient accumulation</a></h2><p>Gradient accumulation adds gradients over an effective batch of size batch_per_step * gradient_accumulation_steps (* num_procs if distributed). Operations related to scaled gradients should occur at effective-batch granularity. The following happens at the end of each effective batch:</p><ul><li>inf/NaN checking</li><li>step skipping if inf/NaN grads are found</li><li>parameter update</li><li>scale update</li></ul><p>Within an effective batch, all grads you accumulate should be scaled, and the scale factor should remain unchanged.</p>
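<p>Dividing each micro-batch loss by gradient_accumulation_steps is enough because gradients are linear in the loss: accumulating the scaled-down micro-batch gradients reproduces the full-batch gradient exactly. A pure-Python check with a one-parameter linear model (grad_mse and the toy data are made up for illustration):</p>

```python
def grad_mse(w, batch):
    """d/dw of the mean squared error of y_hat = w * x over a batch (toy model)."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

data = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 7.0)]
w = 0.5
gradient_accumulation_steps = 2
micro_batches = [data[:2], data[2:]]      # two micro-batches per effective batch

# Full-batch gradient computed in one go:
full = grad_mse(w, data)

# Accumulated: each micro-batch loss is divided by gradient_accumulation_steps,
# so its gradient is too; summing the pieces recovers the full-batch gradient.
accum = sum(grad_mse(w, mb) / gradient_accumulation_steps
            for mb in micro_batches)

print(full, accum)   # -18.0 -18.0
```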
<table><tr><td class="gutter"><pre>1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
</pre></td><td class="code"><pre>scaler = GradScaler()

for epoch in epochs:
    for micro_step in range(gradient_accumulation_steps):
        input, target = get_data(epoch, micro_step)
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target)
            loss = loss / gradient_accumulation_steps
        # Accumulates scaled gradients.
        scaler.scale(loss).backward()

    # If you need to work with unscaled gradients,
    # after all (scaled) grads for the upcoming step have been accumulated,
    # you may unscale_ here if desired (e.g., to allow clipping unscaled gradients).
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
</pre></td></tr></table>
<p>These examples may seem too vanilla; check out <a href="https://github.com/karpathy/nanoGPT/blob/325be85d9be8c81b436728a420e85796c57dba7e/train.py#L290-L314">nanoGPT’s mixed precision training loop</a> for a lively combination of gradient accumulation and gradient clipping.</p><h2 id="working-with-scaled-gradients---gradient-penalty"><a href="https://pytorch.org/docs/stable/notes/amp_examples.html#id7">Working with Scaled Gradients - Gradient penalty</a></h2><p>What? Why?</p><p><a href="https://discuss.pytorch.org/t/whats-the-use-of-scaled-grad-params-in-this-example-of-gradient-penalty-with-scaled-gradients/199741/3">https://discuss.pytorch.org/t/whats-the-use-of-scaled-grad-params-in-this-example-of-gradient-penalty-with-scaled-gradients/199741/3</a></p><h2 id="epilogue">Epilogue</h2><p>This wiki page from Deepgram provides a detailed view of what gradient scaling is about, but I don’t know why it just reads like AI-generated content. Maybe because it gives too many unnecessary details.</p>