JAX: Functional Numerical Computing
Why JAX?
JAX is Google's library for high-performance numerical computing.
It provides a NumPy-compatible API with three superpowers:
automatic differentiation (jax.grad), JIT compilation
(jax.jit via XLA), and automatic vectorization (jax.vmap).
Unlike Numba, JAX operates on pure functions and immutable arrays,
following a functional programming paradigm.
This makes JAX ideal for gradient-based optimization, machine learning research, and scientific computing where composable transformations are valuable.
Definition: JAX NumPy: Drop-in Replacement
jax.numpy mirrors the NumPy API but operates on immutable
JAX arrays (jax.Array). Most NumPy code works by replacing
import numpy as np with import jax.numpy as jnp:
import jax.numpy as jnp
x = jnp.linspace(0, 2 * jnp.pi, 1000)
y = jnp.sin(x) * jnp.exp(-0.1 * x)
Key difference: JAX arrays are immutable. In-place operations
like x[0] = 5 raise an error. Use x = x.at[0].set(5) instead.
Immutability enables JAX to trace and transform functions safely. It guarantees that function transformations (grad, jit, vmap) produce correct results.
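A minimal sketch of the functional update API (the array values here are illustrative):
import jax.numpy as jnp
x = jnp.zeros(5)
# x[0] = 5.0              # would raise TypeError: JAX arrays are immutable
y = x.at[0].set(5.0)      # returns a new array; x is unchanged
z = y.at[1:3].add(1.0)    # sliced functional updates work the same way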
Definition: XLA Compilation with jax.jit
jax.jit compiles a function using XLA (Accelerated Linear Algebra),
Google's optimizing compiler for linear algebra:
import jax
import jax.numpy as jnp
@jax.jit
def selu(x, alpha=1.67, lam=1.05):
    return lam * jnp.where(x > 0, x, alpha * (jnp.exp(x) - 1))
XLA fuses operations, eliminates temporaries, and optimizes memory
layout. Unlike Numba, jax.jit compiles entire computational
graphs, not individual loops.
The first call triggers tracing and compilation. JAX traces the function with abstract values to build a computation graph (jaxpr), then compiles it via XLA.
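To see the traced graph, jax.make_jaxpr prints the jaxpr for a function such as selu above (the input shape here is chosen arbitrarily):
import jax
import jax.numpy as jnp
x = jnp.linspace(-1.0, 1.0, 8)
print(jax.make_jaxpr(selu)(x))   # the jaxpr: JAX's intermediate representation
y = selu(x)                      # first call with this shape: trace + XLA compile + run
y = selu(x)                      # subsequent calls reuse the compiled executable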
Definition: Automatic Differentiation with jax.grad
jax.grad computes exact gradients of scalar-valued functions via
reverse-mode automatic differentiation (backpropagation):
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)
grad_loss = jax.grad(loss)   # d(loss)/d(w)
w_grad = grad_loss(w, x, y)  # same shape as w
Higher-order derivatives compose naturally:
jax.grad(jax.grad(f)) gives the second derivative.
jax.jacobian and jax.hessian handle vector-valued functions.
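A quick sanity check using sin, whose derivatives are known in closed form:
import jax
import jax.numpy as jnp
d_sin = jax.grad(jnp.sin)             # behaves like cos
d2_sin = jax.grad(jax.grad(jnp.sin))  # behaves like -sin
print(d_sin(1.0), jnp.cos(1.0))       # identical up to floating point
print(d2_sin(1.0), -jnp.sin(1.0))
hess = jax.hessian(lambda v: jnp.sum(v ** 2))(jnp.ones(3))  # 2 * identity matrix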
Automatic differentiation is not finite differences (which has truncation error) and not symbolic differentiation (which produces expression swell). AD computes exact derivatives at machine precision.
Definition: Automatic Vectorization with jax.vmap
jax.vmap transforms a function that operates on single examples
into one that operates on batches, without manual broadcasting:
def predict_single(w, x):
    return jnp.dot(w, x)
# Vectorize over the batch of x, keep w fixed
predict_batch = jax.vmap(predict_single, in_axes=(None, 0))
predictions = predict_batch(w, X_batch)  # (batch, features) -> (batch,)
in_axes specifies which argument axis to map over (None for
broadcast, 0 for first axis, etc.).
Definition: Functional Purity in JAX
JAX transformations require pure functions: given the same inputs, the function must always produce the same outputs with no side effects.
Violations include:
- Global state or mutable closures
- In-place array mutation (x[i] = v)
- Python print() inside @jax.jit (runs only during tracing)
- Non-deterministic random numbers (use jax.random with explicit keys)
import numpy as np
# Wrong: hidden global NumPy RNG state
def bad():
    return jnp.array(np.random.randn(3))
# Right: explicit key passed in
def good(key):
    return jax.random.normal(key, shape=(3,))
JAX's random number generation uses a split-key model. Each
random call consumes a key: key, subkey = jax.random.split(key).
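A short sketch of the pattern (the seed is chosen arbitrarily):
import jax
key = jax.random.PRNGKey(0)             # create a key from an integer seed
key, subkey = jax.random.split(key)
a = jax.random.normal(subkey, (3,))     # consume subkey
key, subkey = jax.random.split(key)
b = jax.random.normal(subkey, (3,))     # independent draw; reusing a key repeats numbers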
Theorem: Reverse-Mode AD Complexity
For a function $f: \mathbb{R}^n \to \mathbb{R}$ composed of elementary operations, reverse-mode automatic differentiation computes the full gradient $\nabla f$ at cost
$\operatorname{cost}(\nabla f) \le c \cdot \operatorname{cost}(f)$ for a small constant $c$,
independent of $n$. This is the cheap gradient principle: the gradient costs at most a constant factor more than the function evaluation itself.
Forward mode propagates derivatives alongside the computation ($O(n \cdot \operatorname{cost}(f))$ for $n$ inputs). Reverse mode propagates backward from the output ($O(\operatorname{cost}(f))$ regardless of $n$). This is why backpropagation in neural networks scales to millions of parameters.
Forward pass
Evaluate $f$ and record the computation graph (tape). Cost: $O(\operatorname{cost}(f))$, the same as evaluating $f$.
Backward pass
Traverse the tape in reverse, applying the chain rule at each node. Each elementary operation contributes a constant number of multiplications and additions. Total cost: at most a constant multiple of $\operatorname{cost}(f)$ additional.
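To make the contrast concrete, a sketch comparing jax.jacrev (reverse mode) with jax.jacfwd (forward mode) on a scalar-valued function of many inputs; the function and input size are illustrative:
import jax
import jax.numpy as jnp
def f(w):                        # many inputs, one output
    return jnp.sum(jnp.tanh(w))
w = jnp.ones(1000)
g_rev = jax.jacrev(f)(w)         # reverse mode: ~1 extra pass, independent of n
g_fwd = jax.jacfwd(f)(w)         # forward mode: cost grows with the number of inputs
# both equal jax.grad(f)(w); reverse mode is the cheap choice for scalar outputs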
Example: Gradient Descent Optimization with JAX
Minimize the Rosenbrock function $f(x, y) = (1 - x)^2 + 100\,(y - x^2)^2$ using gradient descent with jax.grad.
Define function and gradient
import jax
import jax.numpy as jnp
def rosenbrock(params):
    x, y = params
    return (1 - x)**2 + 100 * (y - x**2)**2
grad_fn = jax.jit(jax.grad(rosenbrock))
Gradient descent loop
params = jnp.array([-1.0, 1.0])
lr = 0.001
for i in range(5000):
    g = grad_fn(params)
    params = params - lr * g
print(f"Minimum at: {params}")  # close to [1.0, 1.0]
JAX computes the exact gradient at machine precision, and
@jax.jit compiles the gradient computation for speed.
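The Python loop above re-enters the compiled gradient once per iteration; as a sketch, the whole loop can also be compiled with jax.lax.fori_loop (same hyperparameters as above):
@jax.jit
def optimize(params):
    lr = 0.001
    def step(_, p):                          # one gradient-descent update
        return p - lr * jax.grad(rosenbrock)(p)
    return jax.lax.fori_loop(0, 5000, step, params)
params = optimize(jnp.array([-1.0, 1.0]))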
Example: Batched Matrix-Vector Products with vmap
Given a batch of matrices $A_i \in \mathbb{R}^{m \times n}$ and corresponding vectors $x_i \in \mathbb{R}^n$, compute all products $A_i x_i$ without explicit loops.
Using jax.vmap
def matvec(A, x):
    return A @ x
batch_matvec = jax.vmap(matvec)
# A_batch: (B, m, n), x_batch: (B, n)
results = batch_matvec(A_batch, x_batch)  # (B, m)
vmap automatically handles the batch dimension. Under the
hood, XLA compiles this into a single batched GEMV call,
avoiding any Python-level loops.
Automatic Differentiation Explorer
Visualize a function and its JAX-computed gradient. Compare with finite-difference approximation to see AD accuracy.
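In the same spirit, a small sketch comparing the AD gradient with a central finite difference (the test point and step size are chosen for illustration):
import jax
import jax.numpy as jnp
f = lambda x: jnp.sin(x) * jnp.exp(-0.1 * x)
x0, h = 1.5, 1e-3
ad = jax.grad(f)(x0)                      # exact derivative at machine precision
fd = (f(x0 + h) - f(x0 - h)) / (2 * h)    # central difference, O(h^2) truncation error
print(ad, fd)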
JAX vs. PyTorch
| Feature | JAX | PyTorch |
|---|---|---|
| Paradigm | Functional (pure functions) | Object-oriented (nn.Module) |
| Differentiation | jax.grad (function transform) | torch.autograd (tape-based) |
| JIT Compilation | jax.jit via XLA | torch.compile via TorchDynamo |
| Auto-vectorization | jax.vmap (built-in) | torch.vmap (since 2.0) |
| Hardware Support | TPU + GPU + CPU | GPU + CPU |
| Mutability | Immutable arrays | Mutable tensors |
| Random Numbers | Explicit key splitting | Global state (torch.manual_seed) |
| Ecosystem | Flax, Optax, Haiku | torchvision, torchaudio, HuggingFace |
| Best For | Research, custom algorithms | Production ML, rapid prototyping |
Quick Check
Why are JAX arrays immutable?
- To save memory
- To ensure function transformations (grad, jit, vmap) are correct
- Because Python lists are immutable
- To be compatible with TensorFlow
Correct. Immutability guarantees referential transparency, which is required for JAX's function transformations to produce correct results.
Common Mistake: Side Effects Inside jax.jit
Mistake:
Putting print() or other side effects inside a @jax.jit function
and expecting them to execute on every call:
@jax.jit
def f(x):
    print("called!")  # only prints during tracing, not on each execution
    return x ** 2
Correction:
Use jax.debug.print() for debugging inside JIT. For logging,
use jax.debug.callback(). Remember: @jax.jit traces the
function once and compiles it; Python side effects only run during
tracing.
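A sketch of the corrected pattern with jax.debug.print, which is staged into the compiled function and runs on every call:
import jax
@jax.jit
def f(x):
    jax.debug.print("called with x = {x}", x=x)  # executes at runtime, not only during tracing
    return x ** 2
f(3.0)
f(4.0)  # prints both times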
Historical Note: From Autograd to JAX
2015-2018JAX evolved from Autograd, a Python library for automatic differentiation created by Dougal Maclaurin, David Duvenaud, and Matt Johnson at Harvard in 2015. Autograd could differentiate native Python and NumPy code but lacked JIT compilation. In 2018, the same team at Google Brain combined Autograd's tracing approach with XLA compilation to create JAX, achieving both expressiveness and performance.
Why This Matters: JAX for Channel Estimation
JAX's jax.grad and jax.jit are ideal for iterative channel
estimation algorithms. For example, maximum-likelihood channel
estimation requires minimizing a negative log-likelihood such as $\|\mathbf{y} - \mathbf{X}\mathbf{h}\|^2$ over the channel vector $\mathbf{h}$.
With JAX, you write the loss function in NumPy-like syntax,
and jax.grad provides the exact gradient for gradient descent
or L-BFGS, while jax.jit compiles the entire optimization
loop for GPU execution.
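As a hypothetical sketch of that pattern (the pilot matrix X, observation y, and real-valued least-squares objective below are placeholders, not the Chapter 16 formulation):
import jax
import jax.numpy as jnp
def neg_log_likelihood(h, X, y):
    residual = y - X @ h              # y: received samples, X: known pilots, h: channel estimate
    return jnp.sum(residual ** 2)     # Gaussian noise => least-squares objective
grad_nll = jax.jit(jax.grad(neg_log_likelihood))
# one gradient-descent step on the channel estimate:
# h = h - lr * grad_nll(h, X, y)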
See full treatment in Chapter 16
Key Takeaway
JAX is a functional numerical computing library: write pure
functions with jax.numpy, then compose transformations
(jit for speed, grad for derivatives, vmap for batching).
Its strength is composability: jax.jit(jax.vmap(jax.grad(f)))
gives you a compiled, batched gradient function in one line.
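For instance (the per-example loss and batch shapes below are illustrative):
import jax
import jax.numpy as jnp
def loss(w, x):                          # scalar loss for a single example
    return (jnp.dot(w, x) - 1.0) ** 2
# compiled, batched, per-example gradients with respect to w
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))
# grads = per_example_grads(w, X_batch)  # shape: (batch, len(w))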
XLA
Accelerated Linear Algebra: Google's domain-specific compiler for linear algebra that fuses operations, optimizes memory, and targets CPUs, GPUs, and TPUs.
Related: JIT Compilation
Automatic Differentiation
A family of techniques for computing exact derivatives of functions specified as programs, by applying the chain rule to elementary operations. Distinct from symbolic and numerical differentiation.
Related: JIT Compilation
JAX Fundamentals
# Code from: ch14/python/jax_basics.py