Yao Lirong's Blog

Gradient Scaling

2024/04/08

Loss scaling / gradient scaling was mentioned in Mixed-Precision Training as one of its three techniques, but there are many points to be careful about when using it in practice.

Overview: Typical Use Case

Here’s an overview of how to use amp.GradScaler, adapted from the official PyTorch docs.

Background

If the forward pass for a particular op has float16 inputs, then under the Automatic Mixed Precision package (torch.amp) the backward pass for that op will produce gradients of the same data type, float16. Gradient values with small magnitudes may not be representable in float16. These values will flush to zero (“underflow”), so the update for the corresponding parameters will be lost.
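A quick illustration, runnable on CPU (the 2.0 ** 16 used here is just GradScaler’s default initial scale):

import torch

g = torch.tensor(1e-8)         # a small gradient value, fine in float32
print(g.half())                # tensor(0., dtype=torch.float16) -- flushed to zero

# Scaling before the cast keeps the value within float16's representable range:
print((g * 2.0 ** 16).half())  # ~6.55e-4, nonzero in float16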

Code

  1. scaler.scale(loss).backward(): To prevent underflow, “gradient scaling” multiplies the network’s loss(es) by a scale factor and invokes a backward pass on the scaled loss(es). This way, the gradients of all parameters are scaled by the same factor, and we don’t have to worry about them flushing to zero. scaler.scale(loss) multiplies a given loss by the scaler’s current scale factor; we then call backward on this scaled loss.
  2. scaler.step(optimizer): After back-propagation, every learnable parameter holds a gradient that is still scaled to prevent underflow. Before applying whatever learning algorithm (Adam, SGD, …) to them, we have to unscale them so the amount to be updated is correct. scaler.step(optimizer) 1. unscales the gradients, 2. calls optimizer.step(), and does both safely:
    1. Internally invokes unscale_(optimizer) (unless unscale_() was explicitly called for optimizer earlier in the iteration). As part of unscale_(), gradients are checked for infs/NaNs, which signal overflow (for why overflow can happen, see point 3, scaler.update()).
    2. If no inf/NaN gradients are found, invokes optimizer.step() using the unscaled gradients. Otherwise, optimizer.step() is skipped to avoid corrupting the params.
  3. scaler.update(): It would be great if we could just multiply all gradients by a huge number so that absolutely no underflow happens, but doing so can cause overflow. The scaler therefore estimates a good scaling factor for each iteration, so that neither underflow nor overflow happens. scaler.update() updates the scaler’s scale factor for the next iteration, as sketched right below.
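To make point 3 concrete, here is a minimal sketch of the dynamic-scaling logic (not GradScaler’s actual implementation, which runs on-device; the names below mirror its constructor arguments, with defaults growth_factor=2.0, backoff_factor=0.5, growth_interval=2000):

def update(self):
    # Sketch only: roughly what GradScaler does at the end of each iteration.
    if self._found_inf:
        # Gradients overflowed this iteration: back off to a smaller scale.
        self._scale *= self._backoff_factor          # default: halve it
        self._growth_tracker = 0
    else:
        # After growth_interval consecutive clean steps, try a larger scale,
        # trading an occasional skipped step for less underflow.
        self._growth_tracker += 1
        if self._growth_tracker == self._growth_interval:
            self._scale *= self._growth_factor       # default: double it
            self._growth_tracker = 0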
import torch
from torch import autocast
from torch.cuda.amp import GradScaler

scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

Working with Unscaled Gradients - Gradient clipping

Gradient clipping manipulates a set of gradients such that their global norm (torch.nn.utils.clip_grad_norm_()) or maximum magnitude (torch.nn.utils.clip_grad_value_()) is <= some user-imposed threshold.

The “gradients” here of course refer to the original, unscaled gradients. Therefore, you need to call scaler.unscale_(optimizer) before clipping.

scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()

        # Unscales the gradients of optimizer's assigned params in-place
        scaler.unscale_(optimizer)

        # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

        # optimizer's gradients are already unscaled, so scaler.step does not unscale them,
        # although it still skips optimizer.step() if the gradients contain infs or NaNs.
        scaler.step(optimizer)

        scaler.update()

Working with Scaled Gradients - Gradient accumulation

Gradient accumulation adds gradients over an effective batch of size batch_per_step * gradient_accumulation_steps (* num_procs if distributed). Operations related to scaled gradients should occur at effective batch granularity. The following happens at the end of each effective batch:

  • inf/NaN checking
  • step skipping if inf/NaN grads are found
  • parameter update
  • scale update

Within an effective batch, all the grads you accumulate should be scaled, and the scale factor should remain unchanged.

scaler = GradScaler()

for epoch in epochs:
    for micro_step in range(gradient_accumulation_steps):
        input, target = get_data(epoch, micro_step)
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target)
            loss = loss / gradient_accumulation_steps
        # Accumulates scaled gradients.
        scaler.scale(loss).backward()

    # If you need to work with unscaled gradients (e.g. for clipping), you may
    # call scaler.unscale_(optimizer) here, after all the scaled grads for the
    # upcoming step have been accumulated.
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

These examples may seem too vanilla; check out nanoGPT’s mixed-precision training loop for a lively combination of gradient accumulation and gradient clipping.
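A rough sketch of how the two combine, following the patterns above (num_steps, get_batch, and grad_clip are hypothetical names, not nanoGPT’s exact code):

scaler = GradScaler()

for step in range(num_steps):
    # Accumulate scaled grads over the whole effective batch...
    for micro_step in range(gradient_accumulation_steps):
        input, target = get_batch()        # hypothetical data helper
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target) / gradient_accumulation_steps
        scaler.scale(loss).backward()

    # ...then unscale once, clip, and step at effective-batch granularity.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)                 # still skipped if grads contain infs/NaNs
    scaler.update()
    optimizer.zero_grad(set_to_none=True)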

Working with Scaled Gradients - Gradient penalty

What? Why?

https://discuss.pytorch.org/t/whats-the-use-of-scaled-grad-params-in-this-example-of-gradient-penalty-with-scaled-gradients/199741/3
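For context, here is the pattern that thread discusses, adapted from the gradient-penalty example in the PyTorch amp docs. The penalty term must be computed on unscaled gradients, but the intermediate grads returned by torch.autograd.grad are not owned by any optimizer, so you divide by scaler.get_scale() manually instead of calling unscale_():

scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target)

        # Scales the loss for autograd.grad's backward pass, producing scaled gradients
        scaled_grad_params = torch.autograd.grad(outputs=scaler.scale(loss),
                                                 inputs=model.parameters(),
                                                 create_graph=True)

        # Unscales before computing the penalty; these grads are not owned by
        # any optimizer, so plain division replaces scaler.unscale_:
        inv_scale = 1. / scaler.get_scale()
        grad_params = [p * inv_scale for p in scaled_grad_params]

        # Computes the penalty term and adds it to the loss
        with autocast(device_type='cuda', dtype=torch.float16):
            grad_norm = 0
            for grad in grad_params:
                grad_norm += grad.pow(2).sum()
            grad_norm = grad_norm.sqrt()
            loss = loss + grad_norm

        # The backward pass on the augmented loss is scaled as usual,
        # accumulating correctly scaled leaf gradients.
        scaler.scale(loss).backward()

        # step() and update() proceed as usual.
        scaler.step(optimizer)
        scaler.update()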

Epilogue

This wiki page from Deepgram provides a detailed view of what gradient scaling is about, but I don’t know why it reads like AI-generated content. Maybe because it gives too many unnecessary details.
