JAX: Functional Numerical Computing

Why JAX?

JAX is Google's library for high-performance numerical computing. It provides a NumPy-compatible API with three superpowers: automatic differentiation (jax.grad), JIT compilation (jax.jit via XLA), and automatic vectorization (jax.vmap). Unlike Numba, JAX operates on pure functions and immutable arrays, following a functional programming paradigm.

This makes JAX ideal for gradient-based optimization, machine learning research, and scientific computing where composable transformations are valuable.

JAX NumPy: Drop-in Replacement

jax.numpy mirrors the NumPy API but operates on immutable JAX arrays (jax.Array). Most NumPy code works by replacing import numpy as np with import jax.numpy as jnp:

import jax.numpy as jnp

x = jnp.linspace(0, 2 * jnp.pi, 1000)
y = jnp.sin(x) * jnp.exp(-0.1 * x)

Key difference: JAX arrays are immutable. In-place operations like x[0] = 5 raise an error. Use x = x.at[0].set(5) instead.
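A quick sketch of the functional-update API (the array values here are illustrative):

import jax.numpy as jnp

x = jnp.zeros(5)
y = x.at[0].set(5.0)     # returns a new array; x itself is unchanged
z = x.at[1:3].add(1.0)   # .at also supports add, multiply, min, max, ...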

Immutability enables JAX to trace and transform functions safely. It guarantees that function transformations (grad, jit, vmap) produce correct results.

XLA Compilation with jax.jit

jax.jit compiles a function using XLA (Accelerated Linear Algebra), Google's optimizing compiler for linear algebra:

import jax

@jax.jit
def selu(x, alpha=1.67, lam=1.05):
    return lam * jnp.where(x > 0, x, alpha * (jnp.exp(x) - 1))

XLA fuses operations, eliminates temporaries, and optimizes memory layout. Unlike Numba, jax.jit compiles entire computational graphs, not individual loops.

The first call triggers tracing and compilation. JAX traces the function with abstract values to build a computation graph (jaxpr), then compiles it via XLA.
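You can inspect the traced graph yourself: jax.make_jaxpr returns the jaxpr for given example inputs. A small sketch applied to the selu function above:

import jax
import jax.numpy as jnp

# Tracing records operations on abstract shapes, not concrete values
print(jax.make_jaxpr(selu)(jnp.ones(3)))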

Automatic Differentiation with jax.grad

jax.grad computes exact gradients of scalar-valued functions via reverse-mode automatic differentiation (backpropagation):

def loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

grad_loss = jax.grad(loss)    # d(loss)/d(w)
w_grad = grad_loss(w, x, y)   # same shape as w

Higher-order derivatives compose naturally: jax.grad(jax.grad(f)) gives the second derivative. jax.jacobian and jax.hessian handle vector-valued functions.
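A quick check on $f(x) = x^3$, whose second derivative is $6x$:

import jax

f = lambda x: x ** 3
d2f = jax.grad(jax.grad(f))   # compose grad with itself
print(d2f(2.0))               # 12.0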

Automatic differentiation is not finite differences (which suffer from truncation error) and not symbolic differentiation (which suffers from expression swell). AD applies the chain rule to each elementary operation, computing derivatives that are exact up to floating-point round-off.

Automatic Vectorization with jax.vmap

jax.vmap transforms a function that operates on single examples into one that operates on batches, without manual broadcasting:

def predict_single(w, x):
    return jnp.dot(w, x)

# Vectorize over batch of x, keep w fixed
predict_batch = jax.vmap(predict_single, in_axes=(None, 0))
predictions = predict_batch(w, X_batch)  # (batch, features) -> (batch,)

in_axes specifies which argument axis to map over (None for broadcast, 0 for first axis, etc.).
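A sketch of other in_axes patterns, reusing predict_single; the shapes are illustrative assumptions:

import jax
import jax.numpy as jnp

w = jnp.ones(4)        # one weight vector, shape (features,)
X = jnp.ones((8, 4))   # batch of inputs, shape (batch, features)
W = jnp.ones((8, 4))   # a batch of weight vectors

shared = jax.vmap(predict_single, in_axes=(None, 0))(w, X)  # broadcast w -> (8,)
paired = jax.vmap(predict_single, in_axes=(0, 0))(W, X)     # pair up rows -> (8,)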

Functional Purity in JAX

JAX transformations require pure functions: given the same inputs, the function must always produce the same outputs with no side effects.

Violations include:

  • Global state or mutable closures
  • In-place array mutation (x[i] = v)
  • Python print() inside @jax.jit (runs only during tracing)
  • Non-deterministic random numbers (use jax.random with explicit keys), as in the example below

import numpy as np

# Wrong: NumPy's global RNG is hidden mutable state
def bad():
    return jnp.array(np.random.randn(3))

# Right: the PRNG key makes randomness an explicit function input
def good(key):
    return jax.random.normal(key, shape=(3,))

JAX's random number generation uses a split-key model. Each random call consumes a key: key, subkey = jax.random.split(key).
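A minimal sketch of the split-key workflow (the seed and shapes are arbitrary):

import jax

key = jax.random.PRNGKey(0)           # root key from an integer seed
key, subkey = jax.random.split(key)   # split before every random call
x = jax.random.normal(subkey, shape=(3,))
key, subkey = jax.random.split(key)   # a consumed key is never reused
y = jax.random.uniform(subkey, shape=(3,))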

Theorem: Reverse-Mode AD Complexity

For a function $f: \mathbb{R}^n \to \mathbb{R}$ composed of $m$ elementary operations, reverse-mode automatic differentiation computes the full gradient $\nabla f \in \mathbb{R}^n$ in:

$$\text{Cost}(\nabla f) \le 4 \cdot \text{Cost}(f)$$

independent of $n$. This is the cheap gradient principle: the gradient costs at most a constant factor more than the function evaluation itself.

Forward mode propagates derivatives alongside the computation, needing one pass per input ($O(n)$ cost for $n$ inputs). Reverse mode propagates derivatives backward from the output in a single pass ($O(1)$ cost relative to the function evaluation, regardless of $n$). This is why backpropagation in neural networks scales to millions of parameters.
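A rough empirical check of the cheap gradient principle; timings are machine-dependent, so this is a sketch rather than a benchmark:

import time
import jax
import jax.numpy as jnp

f = jax.jit(lambda w: jnp.sum(jnp.tanh(w) ** 2))
g = jax.jit(jax.grad(f))
w = jnp.ones(1_000_000)

f(w).block_until_ready()   # warm up: trigger compilation first
g(w).block_until_ready()

t0 = time.perf_counter(); f(w).block_until_ready(); t_f = time.perf_counter() - t0
t0 = time.perf_counter(); g(w).block_until_ready(); t_g = time.perf_counter() - t0
print(f"grad is {t_g / t_f:.1f}x the cost of f")   # expect a small constant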

Example: Gradient Descent Optimization with JAX

Minimize the Rosenbrock function $f(x, y) = (1-x)^2 + 100(y - x^2)^2$ using gradient descent with jax.grad.
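A minimal sketch of this example; the step size and iteration count are illustrative choices:

import jax
import jax.numpy as jnp

def rosenbrock(p):
    x, y = p
    return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2

grad_f = jax.jit(jax.grad(rosenbrock))

p = jnp.array([-1.0, 1.0])     # starting point
for _ in range(10_000):        # plain gradient descent
    p = p - 1e-3 * grad_f(p)
print(p)                       # moves toward the minimum at (1, 1)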

Example: Batched Matrix-Vector Products with vmap

Given a batch of matrices $A_1, \ldots, A_B \in \mathbb{R}^{m \times n}$ and corresponding vectors $x_1, \ldots, x_B \in \mathbb{R}^n$, compute all products $A_i x_i$ without explicit loops.
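A sketch with batch and dimension sizes chosen for illustration:

import jax
import jax.numpy as jnp

B, m, n = 32, 5, 7
A = jnp.ones((B, m, n))    # batch of matrices A_i
x = jnp.ones((B, n))       # batch of vectors x_i

batched_matvec = jax.vmap(jnp.dot, in_axes=(0, 0))  # map over both batch axes
y = batched_matvec(A, x)   # shape (32, 5): one A_i @ x_i per batch element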

Automatic Differentiation Explorer

Visualize a function and its JAX-computed gradient. Compare with finite-difference approximation to see AD accuracy.
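A sketch of the comparison the explorer makes, using a central finite difference as the baseline:

import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x) * jnp.exp(-0.1 * x)

x0 = 1.0
exact = jax.grad(f)(x0)
h = 1e-5
approx = (f(x0 + h) - f(x0 - h)) / (2 * h)   # O(h^2) truncation error
print(exact, approx, abs(exact - approx))    # AD has no truncation error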

JAX vs. PyTorch

| Feature | JAX | PyTorch |
|---|---|---|
| Paradigm | Functional (pure functions) | Object-oriented (nn.Module) |
| Differentiation | jax.grad (function transform) | torch.autograd (tape-based) |
| JIT Compilation | jax.jit via XLA | torch.compile via TorchDynamo |
| Auto-vectorization | jax.vmap (built-in) | torch.vmap (since 2.0) |
| Hardware Support | CPU + GPU + TPU | CPU + GPU |
| Mutability | Immutable arrays | Mutable tensors |
| Random Numbers | Explicit key splitting | Global state (torch.manual_seed) |
| Ecosystem | Flax, Optax, Haiku | torchvision, torchaudio, HuggingFace |
| Best For | Research, custom algorithms | Production ML, rapid prototyping |

Quick Check

Why are JAX arrays immutable?

  • To save memory
  • To ensure function transformations (grad, jit, vmap) are correct
  • Because Python lists are immutable
  • To be compatible with TensorFlow

Common Mistake: Side Effects Inside jax.jit

Mistake:

Putting print() or other side effects inside a @jax.jit function and expecting them to execute on every call:

@jax.jit
def f(x):
    print("called!")  # Only prints during tracing, not execution
    return x ** 2

Correction:

Use jax.debug.print() for debugging inside JIT. For logging, use jax.debug.callback(). Remember: @jax.jit traces the function once and compiles it; Python side effects only run during tracing.
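A sketch of the corrected version; jax.debug.print stages the print into the compiled computation, so it fires on every call:

import jax

@jax.jit
def f(x):
    jax.debug.print("called with x = {x}", x=x)  # runs at execution time
    return x ** 2

f(3.0)   # prints on this call and on every later call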

Historical Note: From Autograd to JAX

2015-2018

JAX evolved from Autograd, a Python library for automatic differentiation created by Dougal Maclaurin, David Duvenaud, and Matt Johnson at Harvard in 2015. Autograd could differentiate native Python and NumPy code but lacked JIT compilation. In 2018, the same team at Google Brain combined Autograd's tracing approach with XLA compilation to create JAX, achieving both expressiveness and performance.

Why This Matters: JAX for Channel Estimation

JAX's jax.grad and jax.jit are ideal for iterative channel estimation algorithms. For example, maximum-likelihood channel estimation requires minimizing $\|y - Hx\|^2$ over $H$. With JAX, you write the loss function in NumPy-like syntax, and jax.grad provides the exact gradient for gradient descent or L-BFGS, while jax.jit compiles the entire optimization loop for GPU execution.
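A minimal sketch of the idea with synthetic real-valued data; the shapes, step size, and iteration count are illustrative assumptions (Chapter 16 gives the full treatment):

import jax
import jax.numpy as jnp

def loss(H, X, Y):
    return jnp.mean((Y - H @ X) ** 2)    # least-squares channel fit

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
H_true = jax.random.normal(k1, (4, 4))   # unknown channel matrix
X = jax.random.normal(k2, (4, 32))       # 32 pilot snapshots
Y = H_true @ X                           # noiseless observations

grad_loss = jax.jit(jax.grad(loss))      # d(loss)/dH, shape (4, 4)
H = jnp.zeros((4, 4))
for _ in range(500):                     # plain gradient descent
    H = H - 0.1 * grad_loss(H, X, Y)
print(jnp.linalg.norm(H - H_true))       # should be near zero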

See full treatment in Chapter 16

Key Takeaway

JAX is a functional numerical computing library: write pure functions with jax.numpy, then compose transformations (jit for speed, grad for derivatives, vmap for batching). Its strength is composability: jax.jit(jax.vmap(jax.grad(f))) gives you a compiled, batched gradient function in one line.
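A sketch of that composition on a simple scalar function:

import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x) ** 2
df_batched = jax.jit(jax.vmap(jax.grad(f)))   # compiled, batched derivative
xs = jnp.linspace(0.0, jnp.pi, 8)
print(df_batched(xs))   # elementwise 2*sin(x)*cos(x)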

XLA

Accelerated Linear Algebra: Google's domain-specific compiler for linear algebra that fuses operations, optimizes memory, and targets CPUs, GPUs, and TPUs.

Related: JIT Compilation

Automatic Differentiation

A family of techniques for computing exact derivatives of functions specified as programs, by applying the chain rule to elementary operations. Distinct from symbolic and numerical differentiation.

Related: JIT Compilation

JAX Fundamentals

Working examples of jax.numpy, jax.jit, jax.grad, jax.vmap, and the JAX random number system. (Code: ch14/python/jax_basics.py)