🔬 AOTAutograd: Deep Dive
Table of Contents
- 1. The Cross-Backward Fusion Problem
- 2. make_fx and the Dispatch Protocol
- 3. The Joint Forward-Backward Graph
- 4. Functionalization in Depth
- 5. Graph Partitioning via Min-Cut Rematerialization
- 6. The AOTAutograd Runtime Wrapper
- 7. Relationship to torch.func
- References
1. The Cross-Backward Fusion Problem 🔑
The fundamental justification for AOTAutograd is a limitation of eager-mode PyTorch’s autograd engine that prevents a whole class of compiler optimizations. To state it precisely requires understanding how the autograd tape is built.
1.1 Eager-Mode Autograd Tape
In eager mode, every call to an ATen operator that touches a tensor with requires_grad=True adds a gradient function node (a Node subclass such as SinBackward0 or AddBackward0) to the autograd graph. This graph is constructed incrementally, one operator at a time, as Python executes. At .backward() time, PyTorch walks the graph in reverse topological order and calls each node's apply() method. That method runs the op's gradient formula, which dispatches one or more ATen kernels, and passes the resulting gradients on to the nodes that produced its inputs.
The key structural fact is that the backward graph exists only as a runtime linked structure of C++ Node objects, not as a program a compiler can inspect, and it is executed node-by-node rather than as a whole. A compiler that sits between the user's forward code and the ATen kernels therefore sees, at any given moment, exactly one operator's backward computation in isolation.
1.2 Why Adjacent Backward Ops Cannot Be Fused
Consider a loss computation \(\ell = \sin(x) + x^2\). The backward produces two gradient paths:
\[\frac{d\ell}{dx} = \cos(x) + 2x\]
In eager mode, when .backward() executes, the autograd engine first calls AddBackward0.apply(upstream_grad), which simply forwards upstream_grad to both branches. It then calls SinBackward0.apply(upstream_grad), which computes grad * cos(x) by dispatching aten::cos and aten::mul as separate kernel launches, and PowBackward0.apply(upstream_grad), which launches its own elementwise kernels to compute grad * 2x. Finally, because both branches contribute a gradient to x, the engine accumulates the two results with aten::add, which is one more kernel launch.
Each kernel launch reads its inputs from, and writes its outputs to, high-bandwidth memory (HBM). For float32 elementwise operations, each element costs roughly one FLOP against 8 bytes of traffic (unary op) or 12 bytes (binary op), so the arithmetic intensity is around 0.1 FLOP/byte or lower. This is far below the hardware's compute-to-bandwidth ratio, so execution is memory-bandwidth bound. Fusing the whole gradient computation into a single kernel that reads x and the upstream gradient once and writes dx once would replace several HBM round-trips with one, cutting backward memory traffic by roughly 3× or more for this fragment.
Crucially, when the compiler processes SinBackward0, it does not know that PowBackward0 will run right alongside it. The graph is not statically available. Fusion is impossible.
Even if a JIT compiler (e.g., NNC) were attached to each backward node, every node would be a separate compilation unit. Operators that would optimally fuse are compiled independently, and the resulting code pays the full memory-bandwidth cost for each.
1.3 AOTAutograd’s Solution
AOTAutograd resolves this by tracing both the forward and backward passes into a single static torch.fx GraphModule — the joint graph — before any computation occurs. The compiler receives the joint graph as a whole program and can apply cross-operator fusion across the backward nodes just as easily as across forward nodes.
In practice, this cross-op fusion in the backward can substantially reduce backward time for elementwise-heavy code. The functorch benchmarks report one such case, in which NNC compilation cut backward time from 1899.7 μs to 831.2 μs, a ~56% reduction.
This problem establishes the memory-bandwidth argument for backward fusion concretely.
Prerequisites: 1.2 Why Adjacent Backward Ops Cannot Be Fused
Let \(x \in \mathbb{R}^n\) be a float32 tensor. Consider the backward pass through \(\ell = \sin(x) + \cos(x)\), where each of SinBackward, CosBackward, and AddBackward launches a separate kernel. (a) Compute the total bytes read + written for the three-kernel unfused implementation, as a function of \(n\). (b) If the three ops are fused into a single kernel that reads \(x\) and \(\bar{\ell}\) once and writes \(dx\) once, what is the fused byte count? (c) Derive the bandwidth-reduction factor as \(n \to \infty\).
Key insight: Each kernel reads all its inputs and writes all its outputs; intermediate results between fused ops need never touch HBM.
Sketch:
- Unfused. SinBackward reads \(x\) (4n bytes) and \(\bar{\ell}\) (4n bytes) and writes \(\partial_x^{(1)}\) (4n bytes) → 12n bytes. CosBackward reads \(x\) and \(\bar{\ell}\) and writes \(\partial_x^{(2)}\) → 12n bytes. AddBackward reads \(\partial_x^{(1)}\) and \(\partial_x^{(2)}\) and writes \(dx\) → 12n bytes. Total: 36n bytes.
- Fused. One kernel reads \(x\) and \(\bar{\ell}\), accumulates both gradients in registers, and writes \(dx\). Total: \(4n + 4n + 4n = 12n\) bytes.
- Reduction factor: \(36n / 12n = \mathbf{3\times}\), independent of \(n\).
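The byte counts in the sketch can be checked mechanically. The helper below is a hypothetical name introduced only for this check. It assumes, as the problem does, that each unfused backward op is exactly one kernel that reads two n-element tensors and writes one.

```python
def backward_bytes(n: int, bytes_per_elem: int = 4) -> tuple[int, int]:
    """HBM bytes moved by the backward of l = sin(x) + cos(x) (float32 by default)."""
    tensor = n * bytes_per_elem  # size of one n-element tensor
    # Unfused: SinBackward, CosBackward, AddBackward each read 2 tensors and write 1
    unfused = 3 * (2 * tensor + tensor)
    # Fused: one kernel reads x and the upstream gradient once, writes dx once
    fused = 2 * tensor + tensor
    return unfused, fused


n = 1_000_000
unfused, fused = backward_bytes(n)
print(unfused, fused, unfused / fused)  # 36000000 12000000 3.0
```

The ratio does not depend on `n` because both counts are linear in it, which is why the reduction factor in (c) is the same for every tensor size.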
2. make_fx and the Dispatch Protocol 📐
2.1 The PyTorch Dispatch Stack
Every PyTorch operator call passes through a layered dispatch stack managed by the C10 Dispatcher. Simplified, the stack (from outermost to innermost) is:
```mermaid
flowchart TD
    A["Python API call<br/>torch.sin(x)"]
    B["__torch_function__<br/>(Python subclass override)"]
    C["Autograd dispatch key<br/>(builds tape, computes gradient formulas)"]
    D["Batching dispatch key<br/>(vmap)"]
    E["__torch_dispatch__<br/>(TorchDispatchMode / Tensor subclass)"]
    F["Backend kernel<br/>(CUDA / CPU ATen)"]
    A --> B
    B --> C
    C --> D
    D --> E
    E --> F
```
Figure 1: Simplified PyTorch dispatch stack. __torch_function__ fires before dispatch resolution; __torch_dispatch__ fires after all higher dispatch keys, at the bottom of the Python-accessible portion of the stack.
2.2 Why __torch_dispatch__ and Not __torch_function__
__torch_function__ is a Python-level interception hook. It fires when a torch.* API function is called on a subclass, before the call is dispatched into C++. It sees Python-level ops like torch.sin(x) or torch.add(x, y). Critically, it does not see:
- The decomposed ATen kernels that a composite op delegates to (composite C++ implementations are invisible to Python)
- The backward kernels invoked by torch.autograd.grad, because those fire inside the C++ autograd engine, below __torch_function__
__torch_dispatch__ fires after the dispatch stack has resolved — after autograd, after batching, at the deepest Python-accessible point before the backend kernel. It intercepts the raw ATen op calls such as aten::sin.default. Consequently:
- It sees every individual ATen op, including those generated by composite decompositions
- When torch.autograd.grad runs the backward pass inside the autograd engine, the gradient formula kernels (e.g., aten::cos from sin's backward) are issued as ATen ops, and __torch_dispatch__ captures them too
This is why make_fx can record a complete joint graph containing both forward and backward ATen ops.
Formally, __torch_dispatch__ sits at the Python dispatch key in the C10 key set, which has lower priority than AutogradCPU/AutogradCUDA, so autograd handling runs first. During the backward pass, each autograd Node's apply() evaluates its gradient formula, and the ATen calls in that formula re-enter the dispatcher. Those re-entries arrive at __torch_dispatch__ again, yielding a flat sequence of ATen calls spanning the entire forward + backward computation.
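The difference between the two hooks can be observed directly with two logging modes. This is a minimal sketch that assumes TorchFunctionMode from torch.overrides and TorchDispatchMode from the private module torch.utils._python_dispatch. The class names are hypothetical. Both modes log whatever they intercept while the same forward-plus-`torch.autograd.grad` computation runs.

```python
import torch
from torch.overrides import TorchFunctionMode
from torch.utils._python_dispatch import TorchDispatchMode


class FunctionLogger(TorchFunctionMode):
    """Records every callable that reaches __torch_function__."""

    def __init__(self):
        super().__init__()
        self.seen = []

    def __torch_function__(self, func, types, args=(), kwargs=None):
        self.seen.append(func)
        return func(*args, **(kwargs or {}))


class DispatchLogger(TorchDispatchMode):
    """Records every ATen OpOverload that reaches __torch_dispatch__."""

    def __init__(self):
        super().__init__()
        self.seen = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.seen.append(func)
        return func(*args, **(kwargs or {}))


def forward_and_backward(x):
    y = torch.sin(x).sum()
    (gx,) = torch.autograd.grad(y, x)  # backward runs inside the C++ engine
    return gx


x = torch.randn(4, requires_grad=True)
with FunctionLogger() as fn_log:
    forward_and_backward(x)
with DispatchLogger() as disp_log:
    forward_and_backward(x)

aten = torch.ops.aten
# __torch_function__ sees torch.sin but never the cos issued by SinBackward0
print(torch.sin in fn_log.seen, torch.cos in fn_log.seen)  # True False
# __torch_dispatch__ sees both the forward sin and the backward cos
print(aten.sin.default in disp_log.seen, aten.cos.default in disp_log.seen)  # True True
```

The `aten.cos.default` entry is the backward kernel re-entering the dispatcher. This is the hook ProxyTorchDispatchMode relies on.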
2.3 ProxyTorchDispatchMode
make_fx works by activating a dispatch mode called ProxyTorchDispatchMode, a subclass of TorchDispatchMode. TorchDispatchMode is a context-manager-based version of __torch_dispatch__ that does not require a tensor subclass.
When ProxyTorchDispatchMode is active, every ATen operator call is intercepted. The mode:
1. Looks up the corresponding torch.fx proxy for each input tensor
2. Records a call_function node in a torch.fx.Graph object (conceptually, graph.call_function(op, args=proxy_args))
3. Creates a new proxy for the output and associates it with the output tensor
4. Falls through to the real kernel (or, when tracing_mode="fake", to a FakeTensor computation)
The result after the traced function returns is a torch.fx.GraphModule whose graph attribute is the sequence of ATen calls executed.
2.4 How make_fx Works
make_fx(fn)(*example_inputs) returns a GraphModule representing fn’s operations. The call sequence is:
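```python
# Pseudocode for make_fx internals (simplified; decomposition_table handling
# omitted; helper names follow torch.fx.experimental.proxy_tensor, but this
# is not the literal implementation)
def make_fx(fn, decomposition_table=None, tracing_mode="fake"):
    def wrapper(*args):
        # 1. Wrap inputs as FakeTensors (shape/dtype/device metadata, no data)
        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
        fake_args = [fake_mode.from_tensor(a) for a in args]

        # 2. Create an FX tracer, add one placeholder per input, and associate
        #    each placeholder proxy with the corresponding fake tensor
        tracer = PythonKeyTracer()
        tracer.graph = torch.fx.Graph()
        for i, fa in enumerate(fake_args):
            proxy = tracer.create_proxy("placeholder", f"arg_{i}", (), {})
            track_tensor_tree(fa, proxy, constant=None, tracer=tracer)

        # 3. Run fn on the fake tensors. ProxyTorchDispatchMode is entered last,
        #    so it sits on top of the mode stack and records each ATen call as
        #    an FX node before FakeTensorMode computes the output metadata
        with fake_mode, ProxyTorchDispatchMode(tracer, tracing_mode):
            out = fn(*fake_args)

        # 4. Record the outputs and return the GraphModule
        #    (proxies_for is a hypothetical helper that looks up each output's proxy)
        tracer.create_node("output", "output", (proxies_for(out),), {})
        return torch.fx.GraphModule(tracer.root, tracer.graph)

    return wrapper
```
The FakeTensorMode context ensures that shape and dtype propagation works correctly without allocating real memory. The ProxyTorchDispatchMode context captures every ATen call as an FX node.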
2.5 Tracing Through the Autograd Engine
When AOTAutograd constructs the joint graph, it needs to capture not just the forward ops but also the backward ops. This is accomplished by calling torch.autograd.grad(fw_outputs, primals, tangents) inside the make_fx trace — while ProxyTorchDispatchMode is still active.
The autograd engine runs the backward pass in the normal way: it walks the tape accumulated during the forward and calls each gradient function’s apply() method. Each apply() call dispatches ATen ops — aten::cos for SinBackward, aten::mul for PowBackward, etc. Since ProxyTorchDispatchMode is still active on the dispatch stack, these ATen calls are also recorded as call_function nodes in the FX graph.
The result: the FX graph contains a contiguous sequence of nodes spanning forward ops followed by backward ops, all in topological order. This is the joint graph.
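The mechanism can be reproduced with the public make_fx entry point by tracing a function that runs its own forward and then calls torch.autograd.grad. This is a minimal sketch, not AOTAutograd's actual joint-graph builder. It uses make_fx's default real-tensor tracing mode. The exact node list can differ across PyTorch versions; for example, an ones_like node may appear for the implicit scalar gradient.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def fn(x):
    return torch.sin(x).sum()


def joint_fn(x):
    out = fn(x)  # forward ATen ops are recorded
    # Backward: SumBackward0 and SinBackward0 dispatch ATen ops while the
    # proxy mode is still active, so they are recorded in the same graph
    (grad_x,) = torch.autograd.grad(out, x)
    return out, grad_x


x = torch.randn(4, requires_grad=True)
gm = make_fx(joint_fn)(x)
print(gm.code)

aten = torch.ops.aten
targets = [n.target for n in gm.graph.nodes if n.op == "call_function"]
# Forward sin precedes the backward cos in the single flat graph
print(targets.index(aten.sin.default) < targets.index(aten.cos.default))
```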
This problem establishes why FakeTensor mode is necessary during make_fx tracing.
Prerequisites: 2.4 How make_fx Works
Suppose we ran make_fx using real tensors (not FakeTensors). (a) What happens when we call torch.autograd.grad(fw_outputs, primals, tangents) inside the trace for a function \(f: \mathbb{R}^{1000 \times 1000} \to \mathbb{R}\)? (b) Why is this problematic for AOTAutograd, which is invoked at torch.compile time (before the user’s actual data is available)? (c) What property of FakeTensor allows it to support torch.autograd.grad during tracing?
Key insight: FakeTensor propagates shape/dtype metadata without materializing values — it can run the autograd graph structurally without requiring real gradient values.
Sketch:
(a) With real tensors, torch.autograd.grad executes the backward fully: it computes actual gradient values using real arithmetic, consuming real memory proportional to the model size.
(b) AOTAutograd is called at compile time when the user’s data is not yet available (torch.compile operates on example/fake inputs). Using real tensors would require materializing the full backward computation including allocating activation memory — this defeats the purpose of ahead-of-time compilation and would be prohibitively slow as a compilation overhead.
(c) FakeTensor implements every ATen operator as a metadata-only computation: it propagates output shapes, dtypes, and device from input metadata without executing arithmetic. This allows the autograd engine’s backward graph traversal to proceed (calling each gradient function’s structural logic) and ProxyTorchDispatchMode to record the op sequence, without ever computing real numbers.
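Part (c) can be seen directly. The sketch below assumes FakeTensorMode and FakeTensor from the private module torch._subclasses.fake_tensor. It runs the forward and the full autograd backward on a fake 1000 × 1000 input and gets back a gradient that carries only metadata.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

fake_mode = FakeTensorMode()
real_x = torch.randn(1000, 1000, requires_grad=True)
x = fake_mode.from_tensor(real_x)  # copies shape/dtype/device/requires_grad, not data

with fake_mode:
    loss = torch.sin(x).sum()
    # The autograd engine walks the real tape structure, but every kernel it
    # dispatches is a metadata-only fake computation
    (grad_x,) = torch.autograd.grad(loss, x)

print(type(grad_x).__name__, tuple(grad_x.shape), grad_x.dtype)
```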
3. The Joint Forward-Backward Graph 📐
3.1 Formal Definition
Let \(\mathbf{p} \in \mathbb{R}^d\) denote the concatenation of all primal inputs (model parameters and activations) and let \(\boldsymbol{\tau} \in \mathbb{R}^m\) denote the upstream tangent (gradient flowing into this subgraph from above). Let \(f: \mathbb{R}^d \to \mathbb{R}^m\) be the forward function.
Definition (Joint Graph). The joint forward-backward function is
\[J(\mathbf{p},\, \boldsymbol{\tau}) \;=\; \bigl(\,f(\mathbf{p}),\;\; \nabla_{\mathbf{p}}\langle f(\mathbf{p}),\, \boldsymbol{\tau}\rangle\,\bigr)\]
where \(\langle \cdot, \cdot \rangle\) is the standard inner product. The first component is the vector of forward outputs; the second is the vector-Jacobian product (VJP) — equivalently the gradient of the scalar \(\langle f(\mathbf{p}), \boldsymbol{\tau}\rangle\) with respect to \(\mathbf{p}\).
The FX graph that make_fx produces for \(J\) is called the joint graph. Its nodes are ATen operations; some compute forward outputs, others compute gradient outputs. Neither group is annotated — they are simply ordered by data dependency.
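The definition of \(J\) can be checked numerically. The sketch below uses torch.func.vjp on a small hand-picked \(f: \mathbb{R}^3 \to \mathbb{R}^2\). The function and values are illustrative assumptions. It compares the VJP against the gradient of the scalar \(\langle f(\mathbf{p}), \boldsymbol{\tau}\rangle\).

```python
import torch
from torch.func import vjp


def f(p):
    # f: R^3 -> R^2, f(p) = (p0 * p1, sin(p2))
    return torch.stack([p[0] * p[1], torch.sin(p[2])])


def joint(p, tau):
    """J(p, tau) = (f(p), grad_p <f(p), tau>)."""
    out, vjp_fn = vjp(f, p)
    (grad_p,) = vjp_fn(tau)
    return out, grad_p


p = torch.tensor([1.0, 2.0, 0.5])
tau = torch.tensor([3.0, -1.0])
out, grad_p = joint(p, tau)

# Same quantity computed as the gradient of the scalar <f(p), tau>
p_req = p.clone().requires_grad_()
(grad_ref,) = torch.autograd.grad((f(p_req) * tau).sum(), p_req)

# Analytically: grad_p = (3 * p1, 3 * p0, -cos(p2)) = (6, 3, -cos(0.5))
print(grad_p, grad_ref)
```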
3.2 Graph Signature
The joint graph has the following FX placeholder (input) and output structure:
```text
Inputs:  [p_0, p_1, ..., p_{d-1}, tau_0, tau_1, ..., tau_{m-1}]
          ^----- primals ------^  ^-------- tangents --------^
Outputs: ([fw_out_0, ..., fw_out_k], [grad_p_0, ..., grad_p_{d-1}])
```
The primals are the original function arguments. The tangents are synthetic inputs representing upstream gradients — in the loss computation case, typically a scalar 1.0 tangent for the loss tensor.
3.3 Concrete Example: sin + sum
Consider the function:
```python
import torch
from torch._functorch.aot_autograd import aot_function

def fn(x):
    y = torch.sin(x)
    return y.sum()
```
AOTAutograd traces the joint graph \(J(x, \tau)\) where \(x\) is the primal and \(\tau \in \mathbb{R}\) is the scalar tangent for sum's output. The joint graph contains the following ATen nodes (pseudocode FX IR):
```text
%x : [#users=2] = placeholder[target=x]
%tau : [#users=1] = placeholder[target=tangent_0]
# Forward ops
%sin : [#users=1] = call_function[target=aten.sin.default](%x)
%sum_1 : [#users=1] = call_function[target=aten.sum.default](%sin)
# Backward ops (SinBackward0: d/dx sin(x) = cos(x))
%cos : [#users=1] = call_function[target=aten.cos.default](%x)
%expand : [#users=1] = call_function[target=aten.expand.default](%tau, [shape_of_sin])
%mul : [#users=1] = call_function[target=aten.mul.Tensor](%expand, %cos)
output ([%sum_1], [%mul])
```
Key observations:
- %x appears twice: once feeding %sin (forward) and once feeding %cos (backward). After partitioning, %x is the natural candidate for the saved tensor bridging the forward and backward graphs.
- %expand, %cos, and %mul are the backward nodes. A backend such as Inductor can compile them together into a single fused kernel.
- The activation %sin feeds only %sum_1, because the backward needs cos(x), not sin(x). The partitioner therefore chooses between two options: save %x and recompute %cos inside the backward, or compute %cos in the forward and save it instead. Both cost the same number of bytes here, and the min-cut partitioner makes the call.
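To see how the joint graph for this fn is actually split, one can hand aot_function "compilers" that just print and return the graph they receive. This is a minimal sketch. The compiler factory is a hypothetical helper. aot_function uses its default partitioner here, and the printed graphs vary across PyTorch versions.

```python
import torch
from torch._functorch.aot_autograd import aot_function

captured = {}


def make_printing_compiler(name):
    # A "compiler" that records and prints the graph, then runs it unchanged
    def compiler(gm, example_inputs):
        captured[name] = gm
        print(f"--- {name} graph ---")
        print(gm.code)
        return gm
    return compiler


def fn(x):
    return torch.sin(x).sum()


compiled_fn = aot_function(
    fn,
    fw_compiler=make_printing_compiler("forward"),
    bw_compiler=make_printing_compiler("backward"),
)

x = torch.randn(8, requires_grad=True)
loss = compiled_fn(x)  # runs the compiled forward graph
loss.backward()        # runs the compiled backward graph
print(torch.allclose(x.grad, torch.cos(x)))
```

The printed forward graph ends by returning the user output plus whatever the partitioner decided to save, and the backward graph takes those saved values and the tangent as its inputs.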
3.4 derivatives.yaml and JVP Formulas
PyTorch's gradient formulas are defined in tools/autograd/derivatives.yaml. Each entry specifies the VJP (backward) formula for an ATen op in a Python-like DSL. Entries may also give forward-mode (JVP) formulas, but the joint graph uses only the VJP ones. For example (abridged):
```yaml
- name: sin(Tensor self) -> Tensor
  self: grad * self.cos()
```
At build time these formulas are code-generated into C++ autograd Node classes such as SinBackward0. When make_fx traces torch.autograd.grad(...), the autograd engine calls each node's apply(), which dispatches the ATen ops its formula specifies (here aten::cos and aten::mul). These dispatched ops arrive at ProxyTorchDispatchMode and become FX nodes. The joint graph is therefore an unrolled expansion of the derivatives.yaml formulas for all ops in the forward.
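The generated nodes are visible from Python, which makes the yaml-to-node link easy to check. The sketch below assumes node.name() and the `_saved_<name>` attributes that generated nodes expose for their saved variables. It shows that sin saves its input while relu saves its output, and that applying the sin formula by hand matches the engine.

```python
import torch

x = torch.randn(5, requires_grad=True)
y = torch.sin(x)
r = torch.relu(x)

# One code-generated Node class per derivatives.yaml entry
print(y.grad_fn.name(), r.grad_fn.name())

# sin's formula references "self" (the input); relu's references "result" (the output)
print(torch.equal(y.grad_fn._saved_self, x), torch.equal(r.grad_fn._saved_result, r))

# Evaluating "grad * self.cos()" by hand matches what the engine computes
g = torch.randn(5)
(gx,) = torch.autograd.grad(y, x, g)
print(torch.allclose(gx, g * x.cos()))
```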
This problem exercises reading off the joint graph structure for a realistic forward computation.
Prerequisites: 3.3 Concrete Example: sin + sum, 3.4 derivatives.yaml and JVP Formulas
Let \(f(W_1, W_2, x) = \text{relu}(W_2 \cdot \text{relu}(W_1 \cdot x))\) where \(W_1 \in \mathbb{R}^{h \times d}\), \(W_2 \in \mathbb{R}^{o \times h}\), and \(x \in \mathbb{R}^{d \times b}\) is a batch of \(b\) column vectors, so each product lowers to aten.mm (a single vector would give aten.mv instead). (a) List the ATen ops that appear in the joint graph (forward section). (b) List the ATen ops that appear in the backward section, citing the corresponding derivatives.yaml entry for each. (c) Which tensors from the forward are referenced by the backward section, and are therefore candidates for the saved-tensor set?
Key insight: Matrix multiply backward introduces two products, one per operand; relu backward needs a mask of where the pre-activation was positive, which it can read off relu's own output.
Sketch:
(a) Forward ops: aten.mm(W1, x) → aten.relu → aten.mm(W2, h1) → aten.relu → aten.sum (or a loss).
(b) Backward ops: ReluBackward0 uses aten.threshold_backward(grad, result, 0) (derivatives.yaml entry: self: threshold_backward(grad, result, 0)), so it needs relu's output as the mask source. MmBackward0 for y = W·x computes grad_out · xᵀ for the weight gradient and Wᵀ · grad_out for the input gradient (derivatives.yaml's mm entry gives grad · mat2ᵀ for self and selfᵀ · grad for mat2), so it needs both the weight and the input from the forward; the transposes show up as aten.t nodes.
(c) Saved tensor candidates: x (needed for dW1), W1 (needed for dx), W2 (needed to propagate the gradient back through the second mm), h1 = relu(W1 x) (mask source for the inner relu, and needed for dW2), and h2 = relu(W2 h1) (mask source for the outer relu). The pre-activations are needed only if the partitioner decides to recompute a relu in the backward. The partitioner will choose to save a subset and recompute the rest.
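The formulas in this sketch can be verified numerically. The code below is a minimal check under the exercise's assumptions, with small illustrative sizes and a sum as the loss. It runs a hand-written backward that reads only x, W1, W2, h1, and h2 and compares it to autograd. The relu masks come from torch.ops.aten.threshold_backward applied to the relu outputs.

```python
import torch

torch.manual_seed(0)
d, h, o, b = 5, 4, 3, 2
W1 = torch.randn(h, d, requires_grad=True)
W2 = torch.randn(o, h, requires_grad=True)
x = torch.randn(d, b, requires_grad=True)

# Forward: f = relu(W2 @ relu(W1 @ x)), reduced to a scalar loss
h1 = torch.relu(W1 @ x)
h2 = torch.relu(W2 @ h1)
loss = h2.sum()
dW1_ref, dW2_ref, dx_ref = torch.autograd.grad(loss, (W1, W2, x))

# Manual backward that reads only the saved-tensor candidates x, W1, W2, h1, h2
with torch.no_grad():
    g_out = torch.ones_like(h2)                               # d loss / d h2
    g2 = torch.ops.aten.threshold_backward(g_out, h2, 0)      # outer relu, mask from its output
    dW2 = g2 @ h1.T                                           # weight grad of mm(W2, h1)
    g1 = torch.ops.aten.threshold_backward(W2.T @ g2, h1, 0)  # inner relu, mask from its output
    dW1 = g1 @ x.T                                            # weight grad of mm(W1, x)
    dx = W1.T @ g1                                            # input grad of mm(W1, x)

print(torch.allclose(dW1, dW1_ref), torch.allclose(dW2, dW2_ref), torch.allclose(dx, dx_ref))
```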
4. Functionalization in Depth 📐
4.1 Why Compilers Require Functional Graphs
Backend compilers such as TorchInductor operate on a functional intermediate representation: every operation takes tensors as inputs and produces new tensors as outputs, with no side effects. Two categories of PyTorch operations violate this requirement:
- In-place mutations: x.add_(y), x.relu_(), weight.data.copy_(...). These modify existing tensor storage.
- View aliasing: y = x[2:5], y = x.view(...). Here y and x share underlying storage, so a write to y is also a write to x.
A compiler that blindly inlines either category risks miscompiling: it might reorder a mutation relative to a read of the same storage, or it might eliminate a copy that was actually observable from another alias.
Functionalization eliminates both categories by rewriting the program into an equivalent one built only from out-of-place operations, with any mutation of a program input made explicit.
4.2 FunctionalTensor
FunctionalTensor is a tensor subclass that intercepts in-place and view operations. Internally it wraps a real (or fake) tensor called the functional storage. When an in-place op arrives:
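```python
# Conceptual pseudocode: _functional_storage is an illustrative name,
# not the real FunctionalTensor attribute
# User writes: x.add_(y)
# FunctionalTensor intercepts it and instead does:
x_new = x._functional_storage + y   # pure out-of-place op, recorded as an FX node
x._functional_storage = x_new       # rebind the wrapper to the new value; no in-place write
```
The functional storage is rebound, not mutated. The Python object x keeps its identity, so every reference that user code holds still points to it, but internally it now wraps a different tensor value.
View operations are similarly intercepted: y = x.view(...) creates a new FunctionalTensor that records the view operation relative to its base x. The base and the view share bookkeeping for pending mutations, so a later write through y is replayed onto x as out-of-place ops rather than through aliased storage.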
4.3 The functionalize() Transform
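torch.func.functionalize(fn) is a higher-order function that runs fn with functionalization enabled; AOTAutograd gets the same effect internally by running under a FunctionalTensorMode context. Either way, functionalization: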
1. Wraps all input tensors in FunctionalTensor
2. Runs fn — all in-place ops and views are intercepted
3. After fn returns, calls sync_functional_tensor on each output to apply the pending mutations: the rebindings accumulated in the functional storage are committed to the FX graph as explicit copy_ or functional-equivalent nodes
The output is a graph whose every operation is pure.
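The public transform can be combined with make_fx to inspect the result. The sketch below mutates an intermediate through a view. After torch.func.functionalize, the traced graph contains no in-place add_, and it still computes the same value as the original function.

```python
import torch
from torch.func import functionalize
from torch.fx.experimental.proxy_tensor import make_fx


def f(a):
    b = a + 1
    c = b.view(-1)
    c.add_(1)  # mutates b through the view c
    return b


x = torch.randn(2, 2)
gm = make_fx(functionalize(f))(x)
print(gm.code)  # only out-of-place ops (add, view); no add_

targets = [n.target for n in gm.graph.nodes if n.op == "call_function"]
print(torch.ops.aten.add_.Tensor in targets)  # False
```

Passing remove="mutations_and_views" to functionalize would additionally turn the views into *_copy ops, which is useful when a backend cannot handle aliasing at all.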
4.4 Mutation Classification
After functionalization, AOTAutograd classifies each input according to the MutationType enum defined in torch/_functorch/_aot_autograd/schemas.py:
| Class | Meaning | Graph Treatment |
|---|---|---|
| NOT_MUTATED | Input not modified during the forward | No special handling; the input simply flows through |
| MUTATED_IN_GRAPH | Data-only mutation that is safe to keep inside the graph (e.g., when input mutations are kept in-graph and autograd does not need to observe the mutation) | Emitted as an explicit copy_ into the input at the end of the graph; no extra output is needed |
| MUTATED_OUT_GRAPH | Any other mutation, including metadata mutations and data mutations that autograd must see | The updated value is returned as an extra graph output, and an epilogue copy_ (data) or as_strided_ (metadata) is applied at runtime outside the compiled graph |
The runtime wrapper uses this classification to reconstruct correct mutation semantics after the compiled graph returns. Specifically, for MUTATED_OUT_GRAPH inputs, the compiled graph returns the updated tensor as an extra output, and the wrapper applies input.copy_(updated_value) to propagate the mutation back to the original Python tensor object.
4.5 View Aliasing and ViewAndMutationMeta
ViewAndMutationMeta is the central dataclass (in torch/_functorch/_aot_autograd/schemas.py) that records aliasing and mutation information collected during the functionalization pass. Its key fields:
- input_info: a list of InputAliasInfo objects (one per input) recording whether the input was mutated and, if so, how.
- output_info: a list of OutputAliasInfo objects recording whether each output is:
  - OutputType.non_alias: a freshly computed tensor
  - OutputType.alias_of_input: a view of one of the inputs
  - OutputType.is_input: literally one of the inputs
  - OutputType.alias_of_intermediate: a view of an intermediate computed tensor
- mutated_inp_runtime_indices: indices of inputs that require copy_ epilogues at runtime.
After the functional compiled graph executes, the runtime wrapper uses ViewAndMutationMeta to reconstruct the correct aliasing: for outputs of type alias_of_input, it replays the recorded view chain to produce a tensor that genuinely aliases the input’s storage, satisfying user expectations.
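The user-visible guarantee can be checked from the outside. This is a minimal sketch assuming aot_function with a no-op forward compiler. The function returns a slice of its input. If the returned tensor genuinely aliases the input's storage, an in-place write to it must show up in the input.

```python
import torch
from torch._functorch.aot_autograd import aot_function


def take_slice(x):
    return x[2:5]  # the output is a view of the input


compiled = aot_function(take_slice, fw_compiler=lambda gm, example_inputs: gm)

x = torch.zeros(8)
y = compiled(x)
y.add_(1.0)  # write through the returned tensor...
print(x)     # ...and the change is visible in x if aliasing was preserved
```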
4.6 Toy Demonstration
Consider a function with an in-place relu:
```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._functorch.aot_autograd import aot_function

def fn_inplace(x):
    x.relu_()            # in-place mutation of the input
    return x * 2.0

# Before functionalization: make_fx traces through in-place ops literally
gm_before = make_fx(fn_inplace)(torch.randn(4))
print(gm_before.code)
# Roughly (placeholder names vary by version):
# def forward(self, x_1):
#     relu_ = torch.ops.aten.relu_.default(x_1)
#     mul = torch.ops.aten.mul.Tensor(relu_, 2.0)
#     return mul

# After functionalization: in-place ops replaced by functional equivalents
def print_compile(gm, example_inputs):
    print(gm.code)
    return gm

aot_function(fn_inplace, fw_compiler=print_compile)(torch.randn(4))
# One possible output. Names vary, and depending on version and config the
# write-back to x is either an in-graph aten.copy_ or, as shown here, an
# extra graph output that the runtime wrapper copies back into x:
# def forward(self, arg0_1):
#     relu = torch.ops.aten.relu.default(arg0_1)   # functional relu
#     mul = torch.ops.aten.mul.Tensor(relu, 2.0)
#     return (relu, mul)                           # relu = new value of x
```
The in-place relu_ is replaced by a functional relu, and the write-back to x becomes an explicit copy_ epilogue, either inside the graph or in the runtime wrapper. The graph's compute is now pure: the relu node has no side effects, and the mutation is explicit and deferrable.
This problem exercises identifying which mutation class applies to different program patterns.
Prerequisites: 4.4 Mutation Classification, 4.5 View Aliasing and ViewAndMutationMeta
For each of the following patterns, identify the mutation class and describe what the runtime wrapper does after the compiled graph returns:
(a) A forward function that does hidden.add_(bias) where hidden is an intermediate (not a user-visible input).
(b) A forward function that does param.data.copy_(new_val) where param is a model parameter passed as an input.
(c) A forward function that returns x[2:5] where x is an input tensor (a view of the input).
Key insight: The classification depends on whether the mutation target is a user-visible input and whether it can be expressed functionally; aliasing is handled via output type classification, not mutation type.
Sketch:
(a) hidden.add_(bias) where hidden is an intermediate: This is an in-graph op on a non-input tensor. Functionalization replaces it with hidden_new = hidden + bias and any downstream uses of hidden are rewritten to use hidden_new. No mutation epilogue is needed because the user never holds a reference to hidden. Class: effectively NOT_MUTATED from the perspective of inputs.
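(b) param.data.copy_(new_val) where param is a passed input: param is a user-visible input, so its mutation must survive the call. Because the write goes through .data, it is hidden from autograd, which makes it eligible to stay in the graph. When input mutations are kept in-graph, it is MUTATED_IN_GRAPH and the graph itself ends with a copy_ into param. Otherwise it is MUTATED_OUT_GRAPH: the compiled graph returns new_val as an extra output, and the runtime wrapper calls param.copy_(new_val) to make the mutation visible to the caller.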
(c) return x[2:5] where x is an input: This is not a mutation at all — it’s a view aliasing output. ViewAndMutationMeta records this output as OutputType.alias_of_input. The compiled graph returns the underlying data (or a non-aliased copy); the runtime wrapper replays the x[2:5] view chain to produce a tensor that aliases x’s storage correctly.
5. Graph Partitioning via Min-Cut Rematerialization 📐
5.1 The Partitioning Problem
After the joint graph is produced and functionalized, it must be partitioned into two disjoint subgraphs:
- Forward graph \(G_f\): executed at forward time; takes primals as inputs; produces user outputs and a set of saved tensors \(\mathcal{S}\)
- Backward graph \(G_b\): executed at .backward() time; takes saved tensors \(\mathcal{S}\) and tangents as inputs; produces gradients with respect to the primals
The interface between \(G_f\) and \(G_b\) is exactly \(\mathcal{S}\). Every element of \(\mathcal{S}\) must be stored in memory between forward and backward — it represents activation memory. Minimizing \(|\mathcal{S}|\) (by bytes) reduces peak memory.
The core tension: a value needed by \(G_b\) must either be (a) in \(\mathcal{S}\) (saved, at memory cost) or (b) recomputed inside \(G_b\) (at compute cost). The partitioner must choose optimally.
5.2 Min-Cut Formulation
The min_cut_rematerialization_partition function in torch/_functorch/partitioners.py formulates partitioning as a minimum \(s\)-\(t\) cut on a flow network.
Construction of the flow network:
Each FX node \(v\) in the joint graph is split into two flow network nodes, \(v_\text{in}\) and \(v_\text{out}\), connected by an edge with capacity \(w(v)\):
\[\text{edge}(v_\text{in} \to v_\text{out}),\quad \text{capacity} = w(v)\]
Data dependencies between nodes are represented by infinite-capacity edges:
\[\text{if } u \text{ feeds } v: \quad \text{edge}(u_\text{out} \to v_\text{in}),\quad \text{capacity} = \infty\]
A source node \(s\) connects to all primal input nodes. A sink node \(t\) collects from all nodes required by the backward section (identified during graph tracing). Infinite-capacity edges from \(s\) to primal inputs and from required-backward nodes to \(t\) ensure these nodes are never cut.
The min-cut of this network partitions the nodes into a source side (associated with \(G_f\)) and a sink side (associated with \(G_b\)). A node \(v\) is saved (i.e., appears in \(\mathcal{S}\)) if its edge \(v_\text{in} \to v_\text{out}\) crosses the cut. Minimizing the total capacity of the cut minimizes the weighted bytes of saved tensors, subject to every backward requirement being satisfiable.
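The split-node construction can be exercised on a toy graph. The sketch below is self-contained: PyTorch's partitioner delegates the cut to a graph library, and here a small Edmonds–Karp max-flow stands in for it. The graph, its weights, and the partition/min_cut helpers are hypothetical. Banning recomputation is modeled as an infinite edge from the source to \(v_\text{in}\), which is an assumption of this sketch.

```python
from collections import defaultdict, deque

INF = float("inf")


def min_cut(cap, source, sink):
    """Edmonds-Karp max-flow. Returns (cut value, source side of the min cut)."""
    flow = 0.0
    while True:
        parent = {source: None}
        queue = deque([source])
        while queue and sink not in parent:  # BFS for a shortest augmenting path
            u = queue.popleft()
            for v, c in cap[u].items():
                if c > 0 and v not in parent:
                    parent[v] = u
                    queue.append(v)
        if sink not in parent:
            return flow, set(parent)  # nodes still reachable from the source
        bottleneck, v = INF, sink
        while parent[v] is not None:  # smallest residual capacity on the path
            bottleneck = min(bottleneck, cap[parent[v]][v])
            v = parent[v]
        v = sink
        while parent[v] is not None:  # push flow and update residual capacities
            u = parent[v]
            cap[u][v] -= bottleneck
            cap[v][u] = cap[v].get(u, 0.0) + bottleneck
            v = u
        flow += bottleneck


def partition(weights, edges, primals, backward_nodes, banned=()):
    cap = defaultdict(dict)
    for v, w in weights.items():
        cap[v + "_in"][v + "_out"] = w       # node edge carries w(v)
    for u, v in edges:
        cap[u + "_out"][v + "_in"] = INF     # data dependencies are never cut
    for p in primals:
        cap["source"][p + "_in"] = INF
    for b in banned:
        cap["source"][b + "_in"] = INF       # pinned to the forward: never recomputed
    for b in backward_nodes:
        cap[b + "_in"]["sink"] = INF         # backward nodes live in G_b
    cost, src_side = min_cut(cap, "source", "sink")
    saved = [v for v in weights if v + "_in" in src_side and v + "_out" not in src_side]
    return cost, saved


# x (primal, 100 B) -> big = expand(x) (400 B) -> r = relu(big) (440 B) -> g (backward)
weights = {"x": 100.0, "big": 400.0, "r": 440.0, "g": 400.0}
edges = [("x", "big"), ("big", "r"), ("r", "g")]
print(partition(weights, edges, ["x"], ["g"]))                  # (100.0, ['x'])
print(partition(weights, edges, ["x"], ["g"], banned=["big"]))  # (400.0, ['big'])
```

Without the ban, the cheapest cut saves only x, so the backward recomputes big and r from it. Pinning big to the forward forces the cut downstream, and big itself becomes the saved tensor.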
5.3 Edge Weight Design
The edge weight \(w(v)\) for node \(v\) is:
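\[w(v) = \text{bytes}(v) \cdot 1.1^{\max(\min(\delta(v),\, 100),\, 1)} \cdot \bigl(1 + \mathbb{1}[v \text{ fusible with all users}]\bigr)\]
where:
- \(\text{bytes}(v)\) is the memory footprint of \(v\)'s output tensor (from tensor_meta)
- \(\delta(v) \in \mathbb{Z}_{\geq 0}\) is \(v\)'s distance from the backward (nodes closer to the backward get lower weight, biasing the cut toward saving values produced late in the forward)
- The \(1 + \mathbb{1}[\cdot]\) factor doubles the weight of nodes that can be fused with all of their users. This makes them more expensive to save and therefore more likely to be recomputed, since a fusible recomputation is nearly free
Nodes whose outputs cannot be saved (e.g., some non-tensor values) get \(w(v) = \infty\), so their \(v_\text{in} \to v_\text{out}\) edge can never be cut. Nodes that must not be recomputed are handled differently: an infinite-capacity edge \(s \to v_\text{in}\) pins \(v\) to the forward side, so if the backward needs \(v\), either \(v\) itself or some value computed from it must be saved.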
5.4 Recomputation Heuristics
Always ban from recomputation (pinned to the forward via an infinite-capacity edge from \(s\)):
- Random ops: aten::native_dropout, aten::rand_like, aten::randn_like — recomputing would give different values
- Reduction ops that strongly compress data (output size \(\leq\) input size / 4): aten::mean, aten::sum with large reduction dimensions — cheap to save, expensive to recompute
- Compute-intensive ops: aten::mm, aten::bmm, aten::addmm, aten::convolution, aten::upsample_bilinear2d
Always allow recomputation (not pinned; the cut decides based on \(w(v)\)):
- View ops: aten::view, aten::reshape, aten::transpose, aten::permute (zero FLOPs, pure metadata)
- Elementwise activations: aten::relu, aten::gelu, aten::silu, aten::tanh
- Elementwise arithmetic: aten::add, aten::mul, aten::sub, aten::div
- aten::getitem, aten::scalar_tensor
In recent PyTorch versions, min_cut_rematerialization_partition also honors an activation memory budget (the torch._functorch.config.activation_memory_budget setting). When the budget is set to a value in \([0, 1]\), the partitioner runs a 0–1 knapsack solver (dynamic-programming, greedy, or ILP variants from torch/_functorch/_activation_checkpointing/knapsack.py) to find the cheapest set of nodes to recompute while keeping total saved memory within the budget. This provides a principled memory–compute trade-off dial.
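The knapsack view is as follows. Each candidate node has a memory cost if saved and a runtime cost if recomputed. Choosing what to save under a memory cap, so that the avoided recompute time is maximized, is a 0–1 knapsack. The sketch below is a plain dynamic-programming version with a hypothetical function name and integer memory units; it is not PyTorch's solver.

```python
def knapsack_recompute(memory, runtime, max_memory):
    """0-1 knapsack: pick nodes to SAVE so that avoided recompute time is maximal
    while total saved memory stays within max_memory.

    memory[i]:  integer memory units needed to save node i
    runtime[i]: time needed to recompute node i in the backward
    Returns (runtime avoided, indices to save, indices to recompute).
    """
    n = len(memory)
    # best[i][c]: max runtime avoided using the first i nodes within capacity c
    best = [[0.0] * (max_memory + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        m, r = memory[i - 1], runtime[i - 1]
        for c in range(max_memory + 1):
            best[i][c] = best[i - 1][c]                   # recompute node i-1
            if m <= c and best[i - 1][c - m] + r > best[i][c]:
                best[i][c] = best[i - 1][c - m] + r       # save node i-1
    # Walk the table backwards to recover which nodes were saved
    saved, c = [], max_memory
    for i in range(n, 0, -1):
        if best[i][c] != best[i - 1][c]:
            saved.append(i - 1)
            c -= memory[i - 1]
    saved.reverse()
    recompute = [i for i in range(n) if i not in saved]
    return best[n][max_memory], saved, recompute


# Three candidate activations; the budget allows at most 5 units of saved memory
print(knapsack_recompute([4, 3, 2], [5.0, 4.0, 3.0], max_memory=5))  # (7.0, [1, 2], [0])
```

Saving nodes 1 and 2 uses the whole budget and avoids 7.0 units of recompute time, which beats saving the single largest node (5.0). Raising the budget to 9 saves everything.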
5.5 Comparison to torch.utils.checkpoint
| Property | torch.utils.checkpoint | AOTAutograd min-cut |
|---|---|---|
| Granularity | User-specified segments | Per-node in FX graph |
| Timing | Runtime (eager) | Compile-time (static) |
| Optimality | Depends on user-chosen boundaries | Optimal for its cost model within the graph |
| Fusion awareness | None | Yes (fusible recomputed ops are compiled together with their consumers) |
| Random op handling | RNG state stashed and restored on recompute (preserve_rng_state=True by default) | Automatically banned from recompute |
| API | checkpoint(fn, *args) wrapper | Automatic inside torch.compile |
The fusion-awareness point is the critical advantage. When AOTAutograd decides to recompute, say, relu in the backward, it is recomputed inside \(G_b\)’s compiled graph. The recomputed relu is adjacent to the backward nodes that consume it, so the compiler can fuse them into a single kernel — the recomputation has near-zero marginal cost relative to saving the activation.
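For comparison, the eager alternative in the left column looks like this. The block function and sizes are illustrative. torch.utils.checkpoint discards the segment's intermediates in the forward and reruns the segment during backward, so the gradients match an uncheckpointed run. The recomputation, however, happens op by op in eager mode rather than fused into a compiled backward.

```python
import torch
from torch.utils.checkpoint import checkpoint


def block(x, w):
    # Intermediates of this segment are not kept; they are recomputed in backward
    return torch.relu(x @ w).sum(dim=1)


torch.manual_seed(0)
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)

loss_ckpt = checkpoint(block, x, w, use_reentrant=False).sum()
grads_ckpt = torch.autograd.grad(loss_ckpt, (x, w))

loss_ref = block(x, w).sum()
grads_ref = torch.autograd.grad(loss_ref, (x, w))
print(all(torch.allclose(a, b) for a, b in zip(grads_ckpt, grads_ref)))
```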
This problem tests understanding of which structural properties the partitioner must guarantee.
Prerequisites: 5.2 Min-Cut Formulation, 5.4 Recomputation Heuristics
(a) Prove that the min-cut construction guarantees that \(G_b\) can always compute all required gradients, i.e., there is no backward node that lacks a valid input. (b) Explain why banning native_dropout from recomputation does not mean that native_dropout's output is always saved, and describe the two possibilities. (c) If a node \(v\) is a view op treated as costing zero bytes to save, what does the flow network assign as \(w(v)\), and what does the resulting cut placement imply?
Key insight: The min-cut respects data dependencies via infinite-capacity dependency edges. A banned node is pinned to the forward side, so the backward can only obtain it (or values derived from it) through the saved set. A zero-weight node costs nothing to save, so saving it and recomputing it are equally cheap.
Sketch:
(a) Every backward requirement node has an infinite-capacity edge to \(t\). In any finite-cost cut this edge cannot be cut, so the requirement node is always on the sink (backward) side. Because every data-dependency edge also has infinite capacity, each predecessor of such a node is in one of two states. Either it is on the forward side with its \(v_\text{in} \to v_\text{out}\) edge cut (so it is in \(\mathcal{S}\)), or it is on the backward side, where it is recomputed from its own inputs. Inductively, the backward graph has every value it needs.
(b) Banning adds an infinite-capacity edge \(s \to v_\text{in}\), so \(v_\text{in}\) always lies on the forward side and native_dropout is never re-executed in \(G_b\). Two cases remain. (i) Nothing the backward needs depends on \(v\): then \(v\) runs only in \(G_f\) and is neither saved nor recomputed. (ii) The backward needs \(v\)'s output (e.g., the dropout mask) or a value computed from it: then the cut must cross \(v\)'s own edge or an edge downstream of it, so either \(v\) or a value derived from it lands in \(\mathcal{S}\). In neither case is the random op re-run.
(c) With \(\text{bytes}(v) = 0\), \(w(v) = 0\) regardless of the other factors. Cutting a zero-capacity edge is free, so the min-cut never pays anything to save \(v\), and placing \(v\) on either side of the cut has the same cost. Saving the view and recomputing it are therefore equivalent: the backward can replay the view op from its base at negligible cost.
6. The AOTAutograd Runtime Wrapper 🔑
After partitioning, AOTAutograd produces two GraphModule objects — \(G_f\) and \(G_b\) — which are passed to the configured fw_compiler and bw_compiler respectively to produce optimized callables. These callables are then wrapped in a Python torch.autograd.Function subclass called CompiledFunction.
6.1 The CompiledFunction Structure
CompiledFunction is a dynamically generated subclass of torch.autograd.Function. Its two static methods map onto the two halves of the autograd engine's work: forward runs when the function is called and records a backward node in the autograd graph, and backward runs when the engine reaches that node during .backward(). A simplified version:

```python
import torch

# compiled_fw, compiled_bw, and num_user_outputs are closure variables supplied by
# the enclosing AOTAutograd code: the compiled G_f, the compiled G_b, and the
# number of user-visible outputs of G_f.
class CompiledFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, *primals):
        # Run the compiled forward graph
        fw_outs = compiled_fw(*primals)
        # Split: user-visible outputs vs. saved tensors
        user_outputs = fw_outs[:num_user_outputs]
        saved_tensors = fw_outs[num_user_outputs:]
        ctx.save_for_backward(*saved_tensors)
        return tuple(user_outputs)

    @staticmethod
    def backward(ctx, *grad_outputs):
        saved_tensors = ctx.saved_tensors
        # Run the compiled backward graph: one gradient per forward input
        gradients = compiled_bw(*saved_tensors, *grad_outputs)
        return tuple(gradients)
```

The key design choice: saved tensors are made explicit at the graph level. Instead of each ATen backward node individually stashing its own intermediates (the default eager behavior), the partitioner determines the set of saved tensors up front, and \(G_f\) returns them as explicit outputs. This replaces per-node saved-variable bookkeeping with a single save_for_backward call and gives the compiler full visibility into the memory footprint.
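The calling convention above can be modeled without PyTorch. The sketch below hand-writes \(G_f\) and \(G_b\) for \(\ell = \sin(x) + x^2\) from Section 1 on plain Python floats; compiled_fw, compiled_bw, Ctx, and compiled_function here are hypothetical stand-ins, not AOTAutograd's API. \(G_f\) returns the user output followed by the saved values (cos x and x), and \(G_b\) produces cos(x) + 2x in one call instead of three separate backward nodes.

```python
import math

NUM_USER_OUTPUTS = 1  # hypothetical: G_f has one user-visible output


def compiled_fw(x):
    # Stand-in for the compiled forward graph: user output first, then saved values.
    return [math.sin(x) + x * x, math.cos(x), x]


def compiled_bw(cos_x, x, grad_out):
    # Stand-in for the compiled backward graph: d/dx [sin(x) + x^2] = cos(x) + 2x,
    # computed in one step rather than via SinBackward, PowBackward, AddBackward.
    return [grad_out * (cos_x + 2.0 * x)]


class Ctx:
    """Minimal stand-in for the autograd context object."""

    def save_for_backward(self, *values):
        self.saved_tensors = values


def compiled_function(x):
    """Mimic CompiledFunction.forward, returning outputs and a backward closure."""
    ctx = Ctx()
    fw_outs = compiled_fw(x)
    user_outputs = fw_outs[:NUM_USER_OUTPUTS]
    ctx.save_for_backward(*fw_outs[NUM_USER_OUTPUTS:])

    def backward(*grad_outputs):
        # Saved values come first, then the incoming gradients, as in G_b's signature.
        return tuple(compiled_bw(*ctx.saved_tensors, *grad_outputs))

    return tuple(user_outputs), backward


(loss,), backward = compiled_function(0.5)
(grad_x,) = backward(1.0)
print(loss, grad_x)  # sin(0.5) + 0.25, cos(0.5) + 1.0
```

The important property is that the split point between user outputs and saved values is fixed when the graphs are partitioned, so the runtime only slices a list.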
6.2 Two-Stage Pipeline
The AOTAutograd pipeline has two stages that the wrappers implement:
Stage 1 — Graph Capture:
```mermaid
flowchart LR
A["User function fn"]
B["AOTDedupeWrapper
(remove duplicate inputs)"]
C["AOTSyntheticBaseWrapper
(merge aliased inputs)"]
D["create_functionalized_fn
(wrap in FunctionalTensorMode)"]
E["make_fx + create_joint
(joint graph)"]
A --> B --> C --> D --> E
```
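The first wrapper in Stage 1 can be illustrated with a small sketch. Assuming duplicates are detected by object identity (as when the same tensor is passed twice), the graph is traced against the unique inputs only and re-expanded for the user function, while the runtime side drops the duplicate positions before calling the compiled graph. The helpers dedupe_plan, make_deduped_fn, and make_runtime_wrapper are hypothetical, Python lists stand in for tensors, and a real implementation must also check on later calls that the same duplication pattern still holds.

```python
def dedupe_plan(args):
    """Map each argument position to a slot among the unique arguments (by identity)."""
    slot_of_id, index_map, keep_positions = {}, [], []
    for pos, a in enumerate(args):
        if id(a) not in slot_of_id:
            slot_of_id[id(a)] = len(keep_positions)
            keep_positions.append(pos)
        index_map.append(slot_of_id[id(a)])
    return index_map, keep_positions


def make_deduped_fn(fn, index_map):
    # Trace-time side: the graph sees only unique inputs and re-expands them for fn.
    def deduped_fn(*unique_args):
        return fn(*(unique_args[slot] for slot in index_map))
    return deduped_fn


def make_runtime_wrapper(compiled_fn, keep_positions):
    # Run-time side: drop duplicate positions before calling the compiled graph.
    def wrapper(*args):
        return compiled_fn(*(args[pos] for pos in keep_positions))
    return wrapper


def add(a, b):
    return [x + y for x, y in zip(a, b)]


w = [1.0, 2.0]
index_map, keep = dedupe_plan((w, w))           # the same object passed twice
wrapper = make_runtime_wrapper(make_deduped_fn(add, index_map), keep)
print(index_map, keep, wrapper(w, w))           # [0, 0] [0] [2.0, 4.0]
```

Collapsing duplicates before tracing matters because a graph with two independent input slots for one tensor could not express that a mutation through one slot is visible through the other.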
Stage 2 — Compilation and Wrapping:
```mermaid
flowchart LR
E["Joint Graph"]
F["min_cut_rematerialization_partition
or default_partition"]
G["fw_compiler"]
H["bw_compiler"]
I["CompiledFunction wrapper"]
J["RuntimeWrapper
(mutation epilogues)"]
E --> F
F --> G
F --> H
G --> I
H --> I
I --> J
```
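The simpler of the two partitioners in the diagram can be sketched as a no-recomputation split: every forward value that some backward node reads is saved. This is a simplified illustration in the spirit of default_partition, not its actual code; the (name, inputs, region) encoding and the default_partition signature below are hypothetical, and the example is the \(\sin(x) + x^2\) joint graph from Section 1.

```python
def default_partition(joint_nodes):
    """Split a joint graph without recomputation.

    joint_nodes: list of (name, inputs, region) in topological order, where region
    is "fw" for forward nodes and "bw" for backward nodes (tangents count as "bw").
    Returns (forward_names, backward_names, saved_values).
    """
    region = {name: r for name, _, r in joint_nodes}
    forward = [name for name, _, r in joint_nodes if r == "fw"]
    backward = [name for name, _, r in joint_nodes if r == "bw"]
    saved = []
    for _, inputs, r in joint_nodes:
        if r != "bw":
            continue
        for inp in inputs:
            # A forward value read by the backward is always saved, never recomputed.
            if region[inp] == "fw" and inp not in saved:
                saved.append(inp)
    return forward, backward, saved


joint = [
    ("x", [], "fw"),
    ("sin_x", ["x"], "fw"),
    ("sq_x", ["x"], "fw"),
    ("loss", ["sin_x", "sq_x"], "fw"),
    ("tangent", [], "bw"),
    ("cos_x", ["x"], "bw"),
    ("two_x", ["x"], "bw"),
    ("grad_sum", ["cos_x", "two_x"], "bw"),
    ("grad_x", ["tangent", "grad_sum"], "bw"),
]
print(default_partition(joint)[2])  # ['x']
```

Unlike the min-cut partitioner, this rule never trades compute for memory: whatever the backward reads from the forward is saved, even when it would be cheap to recompute.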
6.3 Input Mutation Epilogue
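For inputs classified as MUTATED_OUT_OF_GRAPH (mutations that are not kept inside the graph), the compiled \(G_f\) returns the updated value of each mutated input as an extra output, in addition to the user-visible outputs and the saved tensors. In the implementation these updated inputs come first in \(G_f\)'s output list; CompiledFunction keeps the saved tensors for itself, so its caller sees [*updated_inputs, *user_outputs]. The RuntimeWrapper, the outermost Python wrapper, intercepts those outputs and applies:

```python
# all_outs: outputs returned by CompiledFunction, i.e. [*updated_inputs, *user_outputs]
# inputs: the caller's original arguments
num_mutated = len(mutated_inp_runtime_indices)
updated_inputs, user_outputs = all_outs[:num_mutated], all_outs[num_mutated:]
for inp_idx, updated_value in zip(mutated_inp_runtime_indices, updated_inputs):
    original_input = inputs[inp_idx]
    # Write the new value back into the caller's tensor in place
    original_input.copy_(updated_value)
```

This reconstructs the mutation semantics that the user's code expects (for example, optimizer.step() updating parameters in place) while keeping the compiled graph itself mutation-free.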
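The epilogue pattern can be exercised without PyTorch. In this sketch, Python lists stand in for tensors and slice assignment plays the role of Tensor.copy_; functional_step and runtime_wrapper are hypothetical names, and placing the updated inputs first in the output list is this sketch's own convention.

```python
def functional_step(weights, step):
    """Stand-in for a functionalized graph: instead of updating `weights` in place,
    it returns the new value as an extra output (updated inputs first)."""
    new_weights = [w - step for w in weights]
    user_out = sum(new_weights)
    return [new_weights, user_out]


def runtime_wrapper(compiled_fn, mutated_inp_indices):
    def wrapper(*inputs):
        all_outs = compiled_fn(*inputs)
        n = len(mutated_inp_indices)
        updated, user_outs = all_outs[:n], all_outs[n:]
        # Epilogue: write each updated value back into the caller's original object.
        for inp_idx, new_value in zip(mutated_inp_indices, updated):
            inputs[inp_idx][:] = new_value  # plays the role of Tensor.copy_
        return user_outs
    return wrapper


params = [1.0, 2.0, 3.0]
step_fn = runtime_wrapper(functional_step, mutated_inp_indices=[0])
print(step_fn(params, 0.5), params)  # [4.5] [0.5, 1.5, 2.5]
```

The graph itself only computes and returns values; the caller still observes an in-place update because the wrapper writes into the same list object it was given.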
6.4 Higher-Order Differentiation
Because CompiledFunction is a proper torch.autograd.Function subclass, it participates in PyTorch's autograd engine like any other Function: its outputs carry a grad_fn, and a first-order .backward() or torch.autograd.grad call reaches CompiledFunction.backward normally. Going beyond first order is more restricted:
- Second-order gradients: torch.autograd.grad(loss, params, create_graph=True) runs the backward pass with graph recording enabled so that its results can be differentiated again. CompiledFunction.backward runs the compiled \(G_b\) rather than individually differentiable eager ops, and AOTAutograd raises an error ("torch.compile with aot_autograd does not currently support double backward") if that backward is itself differentiated.
- Composability with torch.func: torch.func.grad(compiled_fn) wraps CompiledFunction in functorch's functional autograd infrastructure.

CompiledFunction.backward calls the compiled \(G_b\); it is not re-traced for higher-order differentiation. Compiled Autograd (enabled with torch._dynamo.config.compiled_autograd = True) extends the AOTAutograd approach in a different direction: it captures the entire backward pass at .backward() time as one graph, instead of letting the eager autograd engine call each node's backward separately.
This problem traces the lifecycle of a saved tensor through AOTAutograd’s wrapper stack.
Prerequisites: 6.1 The CompiledFunction Structure, 6.3 Input Mutation Epilogue
Suppose a forward function has the graph \(G_f\) with outputs [user_out, saved_cos_x] and the backward graph \(G_b\) takes [saved_cos_x, tangent] as inputs. Trace the lifecycle of saved_cos_x:
(a) When is saved_cos_x allocated (what event causes the memory allocation)?
(b) Where is it stored between forward and backward?
(c) When is it freed?
(d) If torch.autograd.grad(..., retain_graph=False) is used (the default), what does PyTorch do with the CompiledFunction’s saved tensors after backward completes?
Key insight: AOTAutograd’s saved tensors follow the same lifecycle as any save_for_backward tensor in standard PyTorch autograd.
Sketch:
(a) saved_cos_x is allocated when compiled_fw executes and returns it as part of fw_outs. The memory is materialized as part of the compiled forward kernel’s output allocation.
(b) It is stored in ctx.saved_tensors, that is, on the backward node that autograd creates for this Function application (the node that appears as the output's grad_fn, CompiledFunctionBackward). Internally the values are held as SavedVariable entries on that node; they are not held by AccumulateGrad nodes, which only accumulate gradients into the .grad of leaf tensors.
(c) Reading ctx.saved_tensors inside CompiledFunction.backward only unpacks the saved values. With retain_graph=False, the engine releases the node's saved variables after its backward has run, dropping the graph's reference to saved_cos_x. If nothing else references it, its memory is returned to the allocator (the caching allocator on CUDA).
(d) With retain_graph=False, PyTorch frees the entire autograd graph rooted at the CompiledFunction node after backward completes — including all saved tensors. saved_cos_x is released. A subsequent call to loss.backward() would fail with “Trying to backward through the graph a second time.”
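A toy model in plain Python of the behavior described in (c) and (d): a node owns its saved values until its backward runs, then drops them unless retain_graph is set, and a second backward fails. Buffer and BackwardNode are hypothetical stand-ins, not PyTorch classes, and the weakref probe relies on CPython's reference counting freeing an object as soon as its last reference disappears.

```python
import weakref


class Buffer:
    """Hypothetical stand-in for a tensor's storage; supports weak references."""

    def __init__(self, values):
        self.values = values


class BackwardNode:
    """Toy autograd node: holds saved values until backward releases them."""

    def __init__(self, *saved):
        self._saved = saved

    def backward(self, grad, retain_graph=False):
        if self._saved is None:
            raise RuntimeError("Trying to backward through the graph a second time")
        (cos_x,) = self._saved
        result = [g * c for g, c in zip(grad, cos_x.values)]
        if not retain_graph:
            self._saved = None  # release: drop the graph's reference to saved values
        return result


node = BackwardNode(Buffer([1.0, 0.5]))  # saved_cos_x is created by the forward
probe = weakref.ref(node._saved[0])      # observe the buffer without owning it
print(node.backward([2.0, 2.0]))         # [2.0, 1.0]
print(probe() is None)                   # True in CPython: the last reference is gone
```

Passing retain_graph=True keeps the saved values alive so that a second backward can reuse them, which is exactly the trade-off the real flag controls.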
7. Relationship to torch.func 💡
7.1 Shared Infrastructure
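torch.func (formerly functorch) and AOTAutograd grew out of the same project (AOTAutograd still lives under torch/_functorch) and share much of their underlying infrastructure:
- Both dispatch through the same C10 dispatch stack.
- Both rely on PyTorch's functionalization pass: torch.func.functionalize exposes it directly, and AOTAutograd applies it through FunctionalTensor / FunctionalTensorMode.
- Both compose with make_fx: AOTAutograd captures graphs with make_fx and ProxyTorchDispatchMode, and make_fx can also trace through torch.func transforms to record the ATen ops they emit.
- Both compute reverse-mode gradients from the same ATen derivative formulas, whether torch.func.grad evaluates them eagerly or create_joint traces them into a joint forward-backward graph.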
The difference is purely in when and how the transformation is applied.
7.2 vmap and BatchedTensor
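torch.func.vmap (vectorized map) works analogously to functionalization, but along the batch dimension. It wraps inputs in BatchedTensor, a tensor wrapper that carries a batch dimension, and routes every ATen op through functorch's batching dispatch key, where a per-op batching rule computes the batched result (for example, a batch of aten::mm calls can be computed as one aten::bmm). That key sits above the Python dispatch key where make_fx's ProxyTorchDispatchMode runs, so vmap composes naturally with make_fx: the tracer records the ATen ops that the batching rules emit.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
# fn: a per-sample function; batched_example: an input with a leading batch dim
batched_fn = torch.func.vmap(fn)
gm = make_fx(batched_fn)(batched_example)
# gm contains the batched ATen ops emitted by vmap's batching rules
```

This composability is the reason torch.func.jacrev, torch.func.hessian, and torch.func.grad nest correctly: each transform pushes its own level onto functorch's transform stack, and an op is handled level by level from the innermost transform outward.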
7.3 AOTAutograd as Ahead-of-Time torch.func.grad
torch.func.grad(f)(*args) applies a VJP transform to f and executes it immediately with *args. The transformation is reapplied on every call.
AOTAutograd applies a VJP transform (via create_joint) once, ahead of time, producing a static graph that is compiled and cached. The compiled function is then called for every subsequent forward + backward pass.
The duality is:
| Property | torch.func.grad | AOTAutograd |
|---|---|---|
| Transform application | Per-call (runtime) | Once at compile time |
| Result | Eager callable | Compiled GraphModule |
| Composability | Arbitrary nesting | Single pass |
| Overhead | Transform overhead every call | Compilation overhead once |
| Graph visibility | No static graph | Full static graph; enables fusion |
Both approaches compute the same gradients (up to floating-point differences introduced by fusion and reordering), so AOTAutograd can be understood as caching the result of the torch.func.grad transformation for a fixed function signature.
One might ask: why not simply JIT-cache torch.func.grad(f)? The answer is that torch.func.grad executes eagerly: it reapplies the transform and launches kernels op by op on every call, and it never materializes a graph. AOTAutograd combines tracing with functionalization, partitioning, and backend compilation in a single pipeline. Tracing grad(f) with make_fx would yield an FX graph, but a graph by itself does not fuse kernels; the backend compiler (Inductor) is what turns the graph into fused Triton kernels.
This problem explores what happens when vmap and grad are composed under AOTAutograd.
Prerequisites: 7.2 vmap and BatchedTensor, 7.3 AOTAutograd as Ahead-of-Time torch.func.grad
Consider f(x) = (x**2).sum() where x is a 1D tensor. We want to compute the per-sample gradient across a batch: torch.func.vmap(torch.func.grad(f))(batch_x) where batch_x has shape (B, D).
(a) In what order do the dispatch modes activate on the stack during the vmap+grad trace?
(b) What is the shape of the FX graph output for the batched-per-sample-gradient computation?
(c) If this combined transform were compiled through AOTAutograd (i.e., torch.compile(torch.func.vmap(torch.func.grad(f)))), what does the joint graph look like at the ATen level, and which ATen ops appear?
Key insight: vmap and grad compose by stacking dispatch modes; the joint graph for the composed transform contains both the batching expansion and the gradient formula ops.
Sketch:
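(a) vmap is the outer transform and grad the inner one, so functorch's transform stack has the vmap (batching) level below the grad (autograd) level. An op such as x**2 is handled from the innermost transform outward: the grad level records it for differentiation, then the vmap level applies the op's batching rule. ProxyTorchDispatchMode is not installed by grad; it only appears when the composed function is itself traced (by make_fx or torch.compile), and it sits below both transforms, recording the batched ATen ops they produce. The path is therefore: Python API → autograd (grad level) → batching (vmap level) → ProxyTorchDispatchMode.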
(b) The output has shape (B, D) — for each sample in the batch, the gradient of f w.r.t. that sample’s x is a vector of shape (D,).
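(c) torch.compile traces through both transforms, so the captured graph contains plain ATen ops rather than vmap or grad calls. Before simplification one would expect the forward ops (aten::pow for x**2 and aten::sum), the backward of the sum (the incoming scalar gradient expanded back to the input shape, e.g. via aten::expand), and the backward of the pow (an aten::mul that forms 2x and multiplies it by that gradient), all operating on tensors of shape (B, D) because the batching rules have already been applied. Since f(x) = (x**2).sum(), grad(f)(x) = 2x elementwise, so the computation reduces to one elementwise multiply over batch_x that Inductor can fuse into a single kernel.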
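A framework-free numerical check of (b) and (c): the sketch below estimates each sample's gradient of f with central differences and compares it with 2x. It loops over samples, which is exactly the work vmap batches away, so it only confirms the (B, D) output shape and the 2x formula; the helper names grad_fd and per_sample_grad are hypothetical.

```python
def f(x):
    """f(x) = (x**2).sum() for one sample x (a list of floats)."""
    return sum(v * v for v in x)


def grad_fd(fn, x, eps=1e-6):
    """Central-difference estimate of the gradient of fn at x."""
    g = []
    for d in range(len(x)):
        plus = list(x)
        plus[d] += eps
        minus = list(x)
        minus[d] -= eps
        g.append((fn(plus) - fn(minus)) / (2 * eps))
    return g


def per_sample_grad(fn, batch):
    """What vmap(grad(fn)) computes: one gradient row per sample, shape (B, D)."""
    return [grad_fd(fn, row) for row in batch]


batch_x = [[1.0, -2.0, 0.5], [3.0, 0.0, -1.5]]           # B=2, D=3
G = per_sample_grad(f, batch_x)
expected = [[2.0 * v for v in row] for row in batch_x]  # grad f = 2x elementwise
```

Because f is quadratic, the central difference is exact up to rounding, so any mismatch with 2x would point to a wrong gradient formula rather than discretization error.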
References
| Reference | Brief Summary | Link |
|---|---|---|
| Ansel et al., “PyTorch 2: Faster Machine Learning Through Dynamic Python Bytecode Transformation and Graph Compilation” (ASPLOS 2024) | Primary source covering the full torch.compile stack including AOTAutograd’s joint graph, functionalization, and partitioning | dl.acm.org |
| PyTorch, “AOT Autograd: How to use and optimize?” (functorch docs) | Official tutorial demonstrating aot_function, graph printing, NNC compilation, and recomputation benchmarks | docs.pytorch.org |
| DeepWiki, “AOT Autograd and Functionalization” | Detailed architectural breakdown of the two-stage pipeline, FunctionalTensor, ViewAndMutationMeta, and wrapper types | deepwiki.com |
| PyTorch Developer Forum, “What (and Why) is __torch_dispatch__?” | Authoritative explanation of __torch_dispatch__’s position in the dispatch stack and its relationship to autograd | dev-discuss.pytorch.org |
| PyTorch Developer Forum, “Difference between torch_function and torch_dispatch” | Precise comparison of the two extension hooks, their firing points, and what they can/cannot intercept | dev-discuss.pytorch.org |
| PyTorch Developer Forum, “How does torch.compile work with autograd?” | Core explanation of how AOTAutograd traces backward via FakeTensor + dispatch and wraps in autograd.Function | dev-discuss.pytorch.org |
| torch/_functorch/aot_autograd.py (PyTorch source) | Canonical implementation: aot_function, aot_module, joint graph construction, runtime wrapper | github.com |
| torch/_functorch/partitioners.py (PyTorch source, via functorch docs) | Implementation of min_cut_rematerialization_partition: NetworkX flow network, edge weights, recomputation heuristics | docs.pytorch.org (0.2.0 module source) |
| torch/_functorch/_aot_autograd/schemas.py (via DeepWiki) | Definition of ViewAndMutationMeta, MutationType, InputAliasInfo, OutputAliasInfo | github.com (pytorch) |
| depyf, “A Walk Through Example of torch.compile” | Concrete walkthrough of CompiledFunction structure and saved tensor lifecycle | depyf.readthedocs.io |
| functorch GitHub README | Overview of composable function transforms (vmap, grad, jacrev) and their relationship to AOTAutograd | github.com |