Engineering Concerns for Deep Learning Training Loops
Table of Contents
- The Core Step Order
- Optimizer Configuration
- Learning Rate Scheduling
- Mixed Precision Training
- Gradient Clipping
- Gradient Accumulation
- Model Mode Management
- Tricks and Engineering Notes
- References
⚙️ The Core Step Order
Every neural network training loop reduces to the same fundamental rhythm:
flowchart LR
A["zero_grad()"] --> B["forward pass"]
B --> C["compute loss"]
C --> D["backward()"]
D --> E["clip gradients"]
E --> F["optimizer.step()"]
F --> G["scheduler.step()"]
G --> A
The ordering is not arbitrary — every step has a causal dependency on the previous one.
| Step | Why it must come here |
|---|---|
| zero_grad() | Clears accumulated gradients from the previous step; must precede backward |
| forward + loss | Builds the computation graph |
| backward() | Populates .grad on every parameter via reverse-mode autodiff |
| clip_grad_norm_ | Operates on .grad tensors; must come after backward and before the update |
| optimizer.step() | Reads .grad to update weights |
| scheduler.step() | Adjusts the learning rate for the next step |
Clipping gradients after the optimizer step is a silent bug — the gradients have already been consumed and the weights updated. The clip has no effect. This is easy to introduce when refactoring.
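The ordering above can be sketched as a single step function. This is a minimal illustration rather than a reference implementation: the toy linear model, the random data, the SGD optimizer, the StepLR schedule, and the `train_step` name are all assumptions chosen to keep the example small.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy setup (illustrative assumptions, not part of the text above).
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
inputs, targets = torch.randn(8, 4), torch.randn(8, 1)


def train_step(x, y, max_norm=1.0):
    optimizer.zero_grad()                          # 1. clear stale gradients
    loss = nn.functional.mse_loss(model(x), y)     # 2-3. forward pass and loss
    loss.backward()                                # 4. populate .grad
    total_norm = torch.nn.utils.clip_grad_norm_(   # 5. clip while .grad is fresh
        model.parameters(), max_norm=max_norm
    )
    optimizer.step()                               # 6. consume .grad
    scheduler.step()                               # 7. set the LR for the next step
    return loss.item(), float(total_norm)


loss_value, grad_norm = train_step(inputs, targets)
```

Moving the `clip_grad_norm_` call below `optimizer.step()` would still run without error, which is why the bug is silent.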
📐 Optimizer Configuration
AdamW and Decoupled Weight Decay
Standard Adam applies weight decay by adding an L2 penalty to the loss:
\[\mathcal{L}_{\text{reg}} = \mathcal{L} + \frac{\lambda}{2} \|\theta\|^2\]
This folds the decay into the gradient: \(g_t \leftarrow g_t + \lambda \theta_t\). Because Adam scales gradients by its second-moment estimate \(\hat{v}_t\), the effective decay rate varies per parameter — parameters with large gradient variance get less decay than intended.
AdamW (Loshchilov & Hutter, 2019) decouples weight decay from the gradient update, applying it directly to the weights after the Adam step:
\[\theta_{t+1} \leftarrow \theta_{t+1}^{\text{Adam}} - \eta \lambda \theta_t\]
This restores weight decay to its intended regularization semantics, independent of the adaptive learning rate.
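The difference can be seen on a single scalar parameter. The sketch below covers only the first Adam step (moment estimates start at zero, so the bias-corrected direction is roughly the sign of the gradient); the function names and the numeric values are illustrative assumptions.

```python
import math


def first_adam_direction(grad, beta1=0.9, beta2=0.999, eps=1e-8):
    """Bias-corrected Adam direction at step 1, with moments initialised to zero."""
    m_hat = (1 - beta1) * grad / (1 - beta1)
    v_hat = (1 - beta2) * grad ** 2 / (1 - beta2)
    return m_hat / (math.sqrt(v_hat) + eps)


def adam_l2_step(theta, grad, lr, weight_decay):
    # L2 penalty: the decay term is folded into the gradient, then rescaled by Adam.
    return theta - lr * first_adam_direction(grad + weight_decay * theta)


def adamw_step(theta, grad, lr, weight_decay):
    # Decoupled: Adam step on the raw gradient, then decay applied to the weight itself.
    return theta - lr * first_adam_direction(grad) - lr * weight_decay * theta


theta, lr, wd = 1.0, 0.1, 0.1
results = {
    grad: (adam_l2_step(theta, grad, lr, wd), adamw_step(theta, grad, lr, wd))
    for grad in (0.01, 10.0)
}
```

In this setting the L2 variant moves the weight by about `lr` regardless of the decay coefficient, because the normalization absorbs the added term. The decoupled variant shrinks the weight by an additional `lr * wd * theta`, independent of the gradient scale.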
Parameter Group Weight Decay
Even with AdamW, applying a uniform weight_decay to all parameters is incorrect. Weight decay pushes parameters toward zero — this is meaningful regularization for weight matrices, where zero means “no connection.” It is not meaningful for:
- Norm scale parameters (e.g. RMSNorm weights, LayerNorm weight/bias): these control layer output magnitude. Decaying them toward zero suppresses normalization.
- Bias parameters: these represent learned offsets; zero is an arbitrary target with no regularization justification.
The standard heuristic, due to nanoGPT, separates parameters by tensor dimension:
# Matrices (dim >= 2) are decayed; vectors such as biases and norm scales are not.
decay_params = [p for n, p in model.named_parameters() if p.dim() >= 2]
nodecay_params = [p for n, p in model.named_parameters() if p.dim() < 2]
optimizer = torch.optim.AdamW([
    {"params": decay_params, "weight_decay": weight_decay},
    {"params": nodecay_params, "weight_decay": 0.0},
], lr=learning_rate)

Weight matrices are always at least 2-dimensional by construction. All parameters that should not be decayed, namely biases (\(d=1\)) and norm scale vectors (\(d=1\)), are 1-dimensional. The geometric property of the tensor directly encodes the semantic distinction.
Embedding tables are 2D (vocab_size × emb_dim) and will receive weight decay under this heuristic. Whether to decay embeddings is debated — some practitioners exclude them explicitly by name. The effect is generally small.
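One way to combine the dimension heuristic with an explicit name-based exclusion for embeddings is sketched below. `TinyModel`, `split_decay_groups`, and `exclude_substrings` are hypothetical names introduced for illustration, and the decay and learning-rate values are placeholders.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical model with one embedding, one linear layer, one norm."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 8)
        self.proj = nn.Linear(8, 8)
        self.norm = nn.LayerNorm(8)

    def forward(self, idx):
        return self.norm(self.proj(self.embed(idx)))


def split_decay_groups(model, exclude_substrings=()):
    """Return (decay, no_decay) dicts mapping parameter name to parameter."""
    decay, no_decay = {}, {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        excluded_by_name = any(s in name for s in exclude_substrings)
        target = no_decay if (param.dim() < 2 or excluded_by_name) else decay
        target[name] = param
    return decay, no_decay


model = TinyModel()
decay, no_decay = split_decay_groups(model, exclude_substrings=("embed",))
optimizer = torch.optim.AdamW(
    [
        {"params": list(decay.values()), "weight_decay": 0.1},
        {"params": list(no_decay.values()), "weight_decay": 0.0},
    ],
    lr=3e-4,
)
```

Returning names alongside parameters makes it easy to print both groups once at startup and confirm that every tensor landed where intended.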
📈 Learning Rate Scheduling
A fixed learning rate is rarely optimal. At the start of training, weights are random and gradients are noisy — a large step risks divergence. Near convergence, smaller steps are needed for precision. The solution is a schedule: \(\eta_t\) varies over training.
Linear Warmup
For the first \(T_w\) steps, ramp linearly from near-zero to the peak learning rate \(\eta_{\max}\):
\[\eta_t = \eta_{\max} \cdot \frac{t}{T_w}, \quad t < T_w\]
Warmup stabilizes early training by giving the optimizer time to build reliable second-moment estimates before taking large steps. GPT-3 warmed up over approximately 375M tokens; LLaMA 2 over 2000 steps.
Starting from \(\eta_0 = 0\) makes the first update a no-op. In practice, use a small \(\epsilon\) (e.g. 1e-8 × η_max) as the initial value.
Cosine Decay
After warmup, decay from \(\eta_{\max}\) to a minimum floor \(\eta_{\min}\) following a cosine curve:
\[\eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\left(\pi \cdot \frac{t - T_w}{T_{\text{total}} - T_w}\right)\right), \quad t \geq T_w\]
Cosine decay is often preferred over linear or step decay because the curve is smooth: the learning rate changes slowly just after warmup and again as it approaches the floor, rather than dropping sharply at fixed milestones as step decay does. This helps avoid the loss spikes that abrupt changes can cause.
The combined schedule:
xychart-beta
title "LR Schedule: Warmup + Cosine Decay"
x-axis "step" [0, Tw, Ttotal]
y-axis "lr"
line [0, 1.0, 0.1]
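The two formulas can be combined into one framework-independent function. This is a sketch under stated assumptions: the `lr_at` name, the `start_factor` floor (which follows the advice not to start from exactly zero), the clamping past the final step, and the example hyperparameters are all illustrative choices.

```python
import math


def lr_at(step, eta_max, eta_min, warmup_steps, total_steps, start_factor=1e-8):
    """Learning rate at `step` for linear warmup followed by cosine decay."""
    if step < warmup_steps:
        # Linear warmup, floored so that step 0 is not an exact no-op.
        return eta_max * max(start_factor, step / warmup_steps)
    # Clamp progress so steps past total_steps stay at the floor.
    progress = min(1.0, (step - warmup_steps) / (total_steps - warmup_steps))
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * progress))


# Sample the schedule every 100 steps for a 1000-step run with 100 warmup steps.
schedule = [lr_at(t, 3e-4, 3e-5, 100, 1000) for t in range(0, 1001, 100)]
```

Plotting `schedule` against the step index is a quick sanity check before a long run: the peak should sit at the end of warmup and the final value should equal the floor.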
PyTorch Implementation
Two idiomatic approaches:
SequentialLR (declarative):
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

# Warmup ramps the LR factor from ~0 to 1; cosine then decays to min_lr.
warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
cosine = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps, eta_min=min_lr)
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps])

LambdaLR (most flexible; HuggingFace's schedule helpers are built on it, and nanoGPT computes the same kind of function by hand):
import math

def lr_lambda(step):
    # Returns a multiplier on the optimizer's base LR, not the LR itself.
    if step < warmup_steps:
        # Floor the factor so step 0 is not a zero-LR no-op.
        return max(1e-8, step / warmup_steps)
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))  # decays from 1 to 0 (no floor)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

scheduler.step() must always be called after optimizer.step(). Calling it before is a common mistake that advances the LR one step early.
⚡ Mixed Precision Training
Standard PyTorch training uses float32 (32 bits per value). Mixed precision runs the forward pass in a lower-precision dtype while preserving float32 for weight updates and loss computation.
Benefits:
- ~2× reduction in activation memory
- Faster matrix multiplications on modern hardware (Tensor Cores, Apple AMX)
- Near-zero accuracy cost when the dtype is chosen correctly
bfloat16 vs float16
Both use 16 bits, but differ in bit allocation:
| Format | Sign | Exponent | Mantissa | Range | Precision |
|---|---|---|---|---|---|
| float32 | 1 | 8 | 23 | \(\pm 3.4 \times 10^{38}\) | ~7 decimal digits |
| bfloat16 | 1 | 8 | 7 | \(\pm 3.4 \times 10^{38}\) | ~2 decimal digits |
| float16 | 1 | 5 | 10 | \(\pm 6.5 \times 10^{4}\) | ~3 decimal digits |
bfloat16 preserves the full exponent range of float32. This matters because gradients can span many orders of magnitude: float16's narrow exponent causes overflow and underflow, requiring a GradScaler workaround, which bfloat16 avoids. Many modern LLMs (e.g. LLaMA, Gemma) are trained in bfloat16.
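The range difference can be checked without any GPU using only the standard library. In this sketch `to_float16` relies on the `struct` module's half-precision format, and `to_bfloat16` is a hand-written emulation (round-to-nearest-even on the top 16 bits of a float32, with NaN and infinity not handled); both helper names and the sample values are assumptions for illustration.

```python
import struct


def to_float16(x):
    """Round a Python float to IEEE float16 and back (raises if out of range)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bfloat16(x):
    """Emulate bfloat16 by rounding a float32 to its top 16 bits (NaN/inf not handled)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounded = (bits + 0x7FFF + ((bits >> 16) & 1)) >> 16  # round to nearest, ties to even
    return struct.unpack("<f", struct.pack("<I", (rounded << 16) & 0xFFFFFFFF))[0]


tiny_grad, big_activation = 1e-8, 70000.0

fp16_tiny = to_float16(tiny_grad)        # below float16's smallest subnormal
bf16_tiny = to_bfloat16(tiny_grad)       # kept, with reduced precision
bf16_big = to_bfloat16(big_activation)   # in range, coarsely rounded

try:
    to_float16(big_activation)
    fp16_overflowed = False
except (OverflowError, struct.error):
    fp16_overflowed = True               # above float16's maximum of about 6.5e4
```

The same two values that float16 loses at either end of its range survive in bfloat16, at the cost of fewer significant digits.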
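With torch.amp.autocast, eligible operations in the forward pass (such as matrix multiplications) run in the lower-precision dtype, and the backward pass of each operation runs in the same dtype autocast chose for its forward. The model weights remain in float32, so the gradients stored in .grad are float32 as well.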
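A small CPU-only check of the dtypes involved is sketched below. It assumes a PyTorch build with CPU bfloat16 autocast support; the layer sizes and variable names are arbitrary.

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.Linear(16, 4)          # parameters are created in float32
x = torch.randn(2, 16)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = layer(x)                # the linear op is autocast to bfloat16
    loss = out.float().pow(2).mean()

loss.backward()

activation_dtype = out.dtype
weight_dtype = layer.weight.dtype
grad_dtype = layer.weight.grad.dtype
print(activation_dtype, weight_dtype, grad_dtype)
```

The activation should come out in bfloat16 while the weight and its stored gradient stay in float32, which is what lets the optimizer update full-precision weights.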
GradScaler and the float16 Pitfall
When using float16, small gradients underflow to zero. GradScaler works around this by multiplying the loss by a large scale factor \(s\) before backprop, then dividing out before the optimizer step:
\[\text{scaled\_loss} = s \cdot \mathcal{L} \quad \Rightarrow \quad \text{scaled\_grad} = s \cdot \nabla_\theta \mathcal{L}\]
If you clip gradients while they are still scaled, you are comparing \(\|s \cdot g\|\) against max_norm — completely wrong units. You must call scaler.unscale_(optimizer) before clip_grad_norm_:
scaler.scale(loss).backward()                       # gradients are scaled by s
scaler.unscale_(optimizer)                          # ← required before clipping
clip_grad_norm_(model.parameters(), max_norm=1.0)   # now compares true norms
scaler.step(optimizer)                              # skips the step if grads contain inf/NaN
scaler.update()

Prefer bfloat16 to avoid GradScaler entirely.
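The arithmetic behind loss scaling, and the reason clipping must see unscaled gradients, can be illustrated with plain floats. This sketch emulates float16 rounding with the standard library's `struct` module; the scale of 2^16 and the gradient values are illustrative assumptions, not GradScaler's actual internal state.

```python
import struct


def to_float16(x):
    """Round a Python float to IEEE float16 and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


scale = 2.0 ** 16                  # illustrative loss scale

# 1. Underflow: a small gradient vanishes in float16 unless it is scaled first.
small_grad = 1e-8
unscaled_fp16 = to_float16(small_grad)            # flushed to zero
scaled_fp16 = to_float16(small_grad * scale)      # representable
recovered = scaled_fp16 / scale                   # unscale in higher precision

# 2. Clipping units: the clip coefficient is min(1, max_norm / norm).
moderate_grad = 0.01
max_norm = 1.0
scaled_norm = abs(to_float16(moderate_grad * scale))
coef_if_clipped_scaled = min(1.0, max_norm / scaled_norm)              # wrong units
coef_if_clipped_unscaled = min(1.0, max_norm / (scaled_norm / scale))  # true norm
```

Clipping the still-scaled gradient would shrink a perfectly reasonable update by a factor of several hundred, whereas the unscaled norm is well under the threshold and is left alone.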
Device-Specific Considerations
Different hardware backends have different dtype support:
import contextlib

import torch

_AUTOCAST_DTYPE = {
    "cuda": torch.bfloat16,  # Ampere+ has hardware bfloat16 support
    "cpu": torch.bfloat16,   # Supported since PyTorch 1.10; speedup requires Intel AMX
    "mps": torch.float16,    # Apple Silicon; bfloat16 autocast support is inconsistent
}

def mp_training_autocast(device: torch.device):
    dtype = _AUTOCAST_DTYPE.get(device.type)
    if dtype is None:
        # Unknown backend: run in full precision.
        return contextlib.nullcontext()
    return torch.amp.autocast(device_type=device.type, dtype=dtype)

contextlib.nullcontext() is the idiomatic way to write “optionally apply a context manager.” It is a context manager that does nothing, allowing a single with block to work whether or not AMP is active.
✂️ Gradient Clipping
Gradient clipping prevents exploding gradients — large loss spikes that send gradients to extreme magnitudes, destabilizing training. The standard approach is global norm clipping: if the \(\ell_2\) norm of all gradients stacked into a single vector exceeds a threshold \(c\), rescale them uniformly:
\[g \leftarrow g \cdot \min\left(1,\ \frac{c}{\|g\|_2}\right)\]
In PyTorch:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

A max norm of \(1.0\) is the conventional default for LLM pretraining (used by GPT-3, LLaMA, nanoGPT). Clipping must occur after backward() and before optimizer.step().
Per-parameter clipping changes the direction of the gradient update, not just its magnitude. Global norm clipping scales the entire gradient vector uniformly, preserving direction. This is a meaningful distinction: direction encodes which way to move, magnitude encodes how far.
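The direction argument can be made concrete with plain lists standing in for gradient tensors. Both helpers below are illustrative sketches: `clip_by_global_norm` follows the formula above directly, and `clip_by_value` shows per-element clipping as one example of a non-global scheme.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale every gradient by min(1, max_norm / global L2 norm)."""
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    scale = min(1.0, max_norm / total_norm) if total_norm > 0 else 1.0
    return [[g * scale for g in tensor] for tensor in grads], total_norm


def clip_by_value(grads, clip_value):
    """Clamp each element independently to [-clip_value, clip_value]."""
    return [[max(-clip_value, min(clip_value, g)) for g in tensor] for tensor in grads]


grads = [[3.0], [0.0, 4.0]]        # two "parameters"; global norm is 5
norm_clipped, total_norm = clip_by_global_norm(grads, max_norm=1.0)
value_clipped = clip_by_value(grads, clip_value=1.0)
```

After global norm clipping the non-zero components keep their 3:4 ratio, so the update points the same way. After per-element clipping they become equal, which is a different direction.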
🔁 Gradient Accumulation
Gradient accumulation simulates a larger effective batch size by running multiple forward–backward passes before taking a single optimizer step. If the true desired batch size is \(B\) but only \(B/k\) samples fit in memory, run \(k\) micro-steps, accumulating gradients across them, then step once.
The effective gradient is:
\[g_{\text{eff}} = \frac{1}{k} \sum_{i=1}^{k} \nabla_\theta \mathcal{L}(x_i)\]
which requires dividing each micro-batch loss by \(k\) before backprop so gradients average rather than sum.
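The equivalence can be verified numerically on a one-parameter least-squares problem. The data, the `mse_grad` helper, and the choice of four equal-sized micro-batches are assumptions for illustration; with unequal micro-batch sizes a plain division by \(k\) would no longer reproduce the full-batch mean.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w * x - y)^2) with respect to the scalar w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.8, 12.1, 14.0, 15.9]

full_batch_grad = mse_grad(w, xs, ys)

k = 4                              # number of micro-batches
micro = len(xs) // k               # samples per micro-batch
accumulated = 0.0                  # plays the role of .grad
unscaled_sum = 0.0                 # what accumulates if the division is forgotten
for i in range(k):
    window = slice(i * micro, (i + 1) * micro)
    micro_grad = mse_grad(w, xs[window], ys[window])
    accumulated += micro_grad / k  # dividing the loss by k divides its gradient by k
    unscaled_sum += micro_grad
```

Forgetting the division leaves a gradient that is \(k\) times too large, which behaves like an unintended \(k\)-fold increase in learning rate for SGD-style updates.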
Plain PyTorch
accumulation_steps = 4
optimizer.zero_grad()
for i, (x, y) in enumerate(train_loader):
    loss = loss_fn(model(x), y) / accumulation_steps  # scale before backward
    loss.backward()                                   # gradients accumulate in .grad
    if (i + 1) % accumulation_steps == 0:
        clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

clip_grad_norm_ must be called after all micro-steps for that accumulation window, not after each individual backward. Clipping per micro-step clips a partial gradient: the norm computed is not representative of the full accumulated gradient.
HuggingFace Accelerate
Accelerate wraps accumulation in a context manager that handles loss scaling and DDP sync suppression automatically:
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)
model, optimizer, dataloader, scheduler = accelerator.prepare(
    model, optimizer, dataloader, scheduler
)
for inputs, labels in dataloader:  # assumes each batch is an (inputs, labels) pair
    with accelerator.accumulate(model):  # manages no_sync + loss scaling
        loss = loss_fn(model(inputs), labels)
        accelerator.backward(loss)  # internally: loss / accumulation_steps
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

Inside accelerator.backward() (simplified from source):
def backward(self, loss, **kwargs):
    loss = loss / self.gradient_accumulation_steps  # scale before backward
    if self.scaler is not None:
        self.scaler.scale(loss).backward(**kwargs)  # fp16: GradScaler path
    else:
        loss.backward(**kwargs)

HuggingFace Trainer handles the parameter group weight decay exclusion with name-based filtering (see Tricks).
PyTorch Lightning
Lightning exposes accumulation as a single Trainer argument — it handles loss normalization and DDP sync suppression internally:
from lightning.pytorch import Trainer

trainer = Trainer(accumulate_grad_batches=4)

A scheduled variant allows decreasing accumulation as training stabilizes:

from lightning.pytorch.callbacks import GradientAccumulationScheduler

# 8 micro-steps for the first 4 epochs, then 4, then 1 from epoch 8 onward
scheduler = GradientAccumulationScheduler(scheduling={0: 8, 4: 4, 8: 1})
trainer = Trainer(callbacks=scheduler)

In Lightning’s automatic optimization mode the training_step just returns a loss, and Lightning divides it by accumulate_grad_batches internally. The full optimizer step (including clipping) only fires on update steps; on accumulation steps the optimizer call is entirely skipped, not merely no-op’d.
With self.automatic_optimization = False, the user controls accumulation manually — identical to plain PyTorch but using self.manual_backward(loss) instead of loss.backward() to route through Lightning’s precision plugin.
DDP and Gradient Sync
In DistributedDataParallel training, PyTorch all-reduces gradients across ranks during every backward() call to keep them synchronized. During accumulation micro-steps this is wasted communication: only the final backward in each window needs to sync.
The fix is to suppress gradient sync on intermediate steps using model.no_sync():
for i, (x, y) in enumerate(train_loader):
    is_last_micro_step = (i + 1) % accumulation_steps == 0
    # Skip the all-reduce on intermediate micro-steps; sync only on the last one.
    ctx = contextlib.nullcontext() if is_last_micro_step else model.no_sync()
    with ctx:
        loss = loss_fn(model(x), y) / accumulation_steps
        loss.backward()
    if is_last_micro_step:
        optimizer.step()
        optimizer.zero_grad()

Both Accelerate (accumulate()) and Lightning (_block_parallel_sync_behavior) handle this automatically. Note that no_sync() is a method of the DistributedDataParallel wrapper, not of nn.Module: on an unwrapped single-GPU model the call raises AttributeError, so fall back to contextlib.nullcontext() when DDP is not active.
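A small helper can make the pattern independent of whether the model is wrapped in DDP, since no_sync() is defined on the DDP wrapper rather than on every module. The sketch below uses duck typing so it runs without a distributed setup; `grad_sync_context` is a hypothetical name, and `PlainModel` and `WrappedModel` are stand-ins for an unwrapped module and a DDP-wrapped one.

```python
import contextlib


def grad_sync_context(model, is_last_micro_step):
    """Return a context that suppresses gradient sync only when the model supports it."""
    no_sync = getattr(model, "no_sync", None)
    if is_last_micro_step or no_sync is None:
        return contextlib.nullcontext()   # sync normally, or nothing to suppress
    return no_sync()


class PlainModel:
    """Stand-in for an unwrapped module: no no_sync() method."""


class WrappedModel:
    """Stand-in for a DDP-wrapped module: counts how often sync was suppressed."""

    def __init__(self):
        self.suppressed = 0

    @contextlib.contextmanager
    def no_sync(self):
        self.suppressed += 1
        yield


wrapped = WrappedModel()
for micro_step in range(4):
    with grad_sync_context(wrapped, is_last_micro_step=(micro_step == 3)):
        pass  # forward and backward would run here

with grad_sync_context(PlainModel(), is_last_micro_step=False):
    pass      # falls back to nullcontext instead of raising AttributeError
```

In a real loop the body of each `with` block would hold the forward pass and `loss.backward()`, exactly as in the snippet above.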
🔄 Model Mode Management
PyTorch modules carry a training: bool flag (accessible as model.training) that changes the behavior of stateful layers:
| Layer | .train() mode | .eval() mode |
|---|---|---|
| nn.Dropout | zeroes activations at rate \(p\) | identity (pass-through) |
| nn.BatchNorm | uses batch statistics | uses running statistics |
| nn.LayerNorm, RMSNorm, nn.Linear | no difference | no difference |
The standard pattern:
# During evaluation
model.eval()
with torch.no_grad():
    ...
model.train()  # restore

# During training
model.train()

model.eval() and torch.no_grad() are orthogonal. model.eval() changes what the model computes (disables dropout stochasticity). torch.no_grad() changes how PyTorch tracks computation (disables gradient graph construction, saving memory and compute). Both are needed during inference.
The unconditional model.train() after eval incorrectly assumes the model was in training mode before the call. A context manager that captures and restores the original state is more robust:
from contextlib import contextmanager

import torch

@contextmanager
def eval_mode(model: torch.nn.Module):
    was_training = model.training  # remember the caller's mode
    model.eval()
    try:
        yield
    finally:
        model.train(was_training)  # restore it even if evaluation raises

🔑 Tricks and Engineering Notes
A running collection of non-obvious implementation details.
Parameter Group Weight Decay
Problem: AdamW applies weight decay uniformly to all parameters. Weight decay pushes values toward zero — a meaningful regularizer for weight matrices, but harmful for norm scale parameters (pushing layer outputs toward zero) and biases (arbitrary zero target).
Heuristic: separate parameters by ndim. All parameters that should not be decayed happen to be 1D; all weight matrices are ≥ 2D.
decay = [p for _, p in model.named_parameters() if p.dim() >= 2]
no_decay = [p for _, p in model.named_parameters() if p.dim() < 2]
optimizer = torch.optim.AdamW([
    {"params": decay, "weight_decay": wd},
    {"params": no_decay, "weight_decay": 0.0},
], lr=lr)

Origin: nanoGPT (configure_optimizers). Widely adopted in LLM pretraining.
HuggingFace Trainer variant — name-based exclusion: rather than relying on ndim, HuggingFace filters by parameter name using regex patterns. This is more explicit and handles edge cases like named 2D norm parameters:
# From transformers/trainer.py
forbidden = [r"bias", r"layernorm", r"rmsnorm", r"(?:^|\.)norm(?:$|\.)", r"_norm(?:$|\.)"]
decay_names = get_parameter_names(model, [nn.LayerNorm], forbidden)
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters()
                   if n in decay_names and p.requires_grad],
        "weight_decay": args.weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters()
                   if n not in decay_names and p.requires_grad],
        "weight_decay": 0.0,
    },
]

The dim >= 2 heuristic and the name-based approach should agree in practice. The name-based approach is more defensive when third-party layers use non-standard parameter shapes.
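To see what the patterns listed above catch, they can be applied to a few parameter names with the standard `re` module. This sketch assumes the patterns are matched with a regex search against the lowercased parameter name, which may differ in detail from what the library does; the example names and the `is_excluded_by_name` helper are made up for illustration.

```python
import re

FORBIDDEN_PATTERNS = [
    r"bias",
    r"layernorm",
    r"rmsnorm",
    r"(?:^|\.)norm(?:$|\.)",   # a path component that is exactly "norm"
    r"_norm(?:$|\.)",          # a path component ending in "_norm"
]


def is_excluded_by_name(param_name):
    """True if any forbidden pattern matches the lowercased parameter name."""
    lowered = param_name.lower()
    return any(re.search(pattern, lowered) for pattern in FORBIDDEN_PATTERNS)


names = [
    "encoder.layer.0.attention.dense.weight",
    "encoder.layer.0.attention.dense.bias",
    "layers.0.input_layernorm.weight",
    "layers.0.post_norm.weight",
    "model.norm.weight",
    "normalizer.weight",
]
excluded = {name: is_excluded_by_name(name) for name in names}
```

The anchored patterns match `norm` only as a whole path component or suffix, so a name such as `normalizer.weight` is not excluded by accident.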
torch.compile()
Compiles the model’s computation graph with TorchInductor, fusing kernels and eliminating Python overhead. Applied after model initialization:
model = torch.compile(model)

Typically yields 15–30% throughput improvement on CUDA with no code changes beyond this line. Requires PyTorch 2.0+.
TF32 Matmul Precision
On NVIDIA Ampere and newer GPUs, PyTorch can use TF32 for float32 matrix multiplications — a 19-bit format with the same exponent range as float32 but reduced mantissa precision. The accuracy loss is negligible for deep learning while throughput improves significantly.
torch.set_float32_matmul_precision("high")  # enable TF32

Call this once before model initialization. Has no effect on non-CUDA devices.
📚 References
| Reference | Brief Summary | Link |
|---|---|---|
| Loshchilov & Hutter, “Decoupled Weight Decay Regularization” (2019) | Introduced AdamW; showed that L2 regularization in Adam does not equal weight decay | arXiv:1711.05101 |
| Karpathy, nanoGPT | Minimal, readable GPT-2 pretraining reference implementation | GitHub |
| PyTorch AMP documentation | Official guide to torch.amp.autocast and GradScaler | pytorch.org |
| Touvron et al., “LLaMA 2” (2023) | Training config reference: bfloat16, cosine schedule, AdamW with group WD | arXiv:2307.09288 |
| HuggingFace Accelerate source (accelerator.py) | Reference implementation of gradient accumulation, loss scaling, no_sync dispatch, and GradScaler unscaling in a framework context | GitHub |
| HuggingFace Transformers Trainer (trainer.py) | Reference implementation of parameter group weight decay via name-based regex filtering | GitHub |
| PyTorch Lightning training tricks docs | Canonical Lightning patterns for gradient accumulation, mixed precision, and gradient clipping via Trainer arguments | lightning.ai |