Implementing Proximal Operators

Why Proximal Operators?

Many modern optimization problems have objectives that are the sum of a smooth term and a non-smooth regularizer: $\min_{\mathbf{x}} f(\mathbf{x}) + g(\mathbf{x})$. Gradient descent cannot handle the non-smooth $g$ directly. Proximal operators generalize the gradient step to non-smooth functions, enabling algorithms like ISTA, FISTA, and ADMM.

In signal processing, proximal operators implement denoising, sparsity promotion, and constraint projection as simple closed-form operations.

Definition: Proximal Operator

The proximal operator of a function $g : \mathbb{R}^n \to \mathbb{R} \cup \{+\infty\}$ with step size $\lambda > 0$ is:

$$\mathrm{prox}_{\lambda g}(\mathbf{v}) = \arg\min_{\mathbf{x}} \left\{ g(\mathbf{x}) + \frac{1}{2\lambda}\|\mathbf{x} - \mathbf{v}\|_2^2 \right\}$$

Interpretation: find the point that balances minimizing $g$ against staying close to $\mathbf{v}$. When $g = 0$, the proximal operator is the identity. When $g$ is the indicator function of a set $C$ ($g(\mathbf{x}) = 0$ if $\mathbf{x} \in C$, $+\infty$ otherwise), the proximal operator is the projection onto $C$.

The proximal operator is well defined (the minimizer exists and is unique) when $g$ is closed, proper, and convex.
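Before deriving closed forms, it can help to sanity-check a prox numerically by solving the defining minimization directly. A minimal sketch, assuming NumPy and SciPy are available; the helper prox_numeric is hypothetical, for illustration only, and Nelder-Mead is one reasonable choice for the non-smooth objective:

import numpy as np
from scipy.optimize import minimize

def prox_numeric(g, v, lam):
    """Evaluate prox_{lam*g}(v) by direct numerical minimization (illustrative only)."""
    obj = lambda x: g(x) + np.sum((x - v) ** 2) / (2 * lam)
    # Nelder-Mead tolerates the non-smooth term; the result is approximate
    return minimize(obj, x0=np.asarray(v, dtype=float), method="Nelder-Mead").x

v = np.array([2.0, -0.3, 0.7])
print(prox_numeric(lambda x: np.sum(np.abs(x)), v, lam=0.5))
# approximately [1.5, 0.0, 0.2], matching the soft-thresholding closed form derived next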

Definition: Soft-Thresholding Operator

The proximal operator of $g(\mathbf{x}) = \|\mathbf{x}\|_1$ is the soft-thresholding (shrinkage) operator, applied element-wise:

$$\mathcal{S}_\lambda(v_i) = \mathrm{sign}(v_i)\,\max(|v_i| - \lambda, 0)$$

import numpy as np

def soft_threshold(v, lam):
    """Proximal operator of lam * ||x||_1 (element-wise shrinkage)."""
    return np.sign(v) * np.maximum(np.abs(v) - lam, 0)

This shrinks each component toward zero by $\lambda$. Components with $|v_i| \le \lambda$ are set exactly to zero, producing sparsity: the mechanism behind LASSO and compressed sensing.
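A quick check of the shrinkage behavior on arbitrary values:

v = np.array([3.0, -0.5, 1.2, 0.1, -2.0])
print(soft_threshold(v, lam=1.0))
# [ 2.  -0.   0.2  0.  -1. ]: everything inside [-1, 1] collapses to zero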

Definition: Projection as Proximal Operator

The proximal operator of the indicator function $\iota_C$ of a closed convex set $C$ is the Euclidean projection:

$$\mathrm{prox}_{\iota_C}(\mathbf{v}) = \Pi_C(\mathbf{v}) = \arg\min_{\mathbf{x} \in C} \|\mathbf{x} - \mathbf{v}\|_2$$

Common projections:

# Projection onto non-negative orthant
def proj_nonneg(v):
    return np.maximum(v, 0)

# Projection onto L2 ball of radius r
def proj_l2_ball(v, r):
    norm_v = np.linalg.norm(v)
    return v * min(1, r / norm_v) if norm_v > 0 else v

# Projection onto simplex {x >= 0, sum(x) = 1}
def proj_simplex(v):
    """Euclidean projection onto the probability simplex (Duchi et al.)."""
    n = len(v)
    u = np.sort(v)[::-1]                 # sort components in descending order
    cssv = np.cumsum(u) - 1              # cumulative sums, shifted by the target sum 1
    rho = np.nonzero(u > cssv / np.arange(1, n + 1))[0][-1]  # last index where the condition holds
    theta = cssv[rho] / (rho + 1)        # optimal threshold
    return np.maximum(v - theta, 0)
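For example, the simplex projection of an arbitrary hand-picked vector returns a valid probability vector:

v = np.array([0.9, 0.6, -0.2])
x = proj_simplex(v)
print(x, x.sum())  # [0.65 0.35 0.  ] 1.0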

Definition: Group Soft-Thresholding

For the group LASSO penalty $g(\mathbf{x}) = \sum_g \|\mathbf{x}_g\|_2$ (the mixed $\ell_{2,1}$ norm), the proximal operator applies block-wise shrinkage:

$$[\mathrm{prox}_{\lambda g}(\mathbf{v})]_g = \mathbf{v}_g \cdot \max\!\left(1 - \frac{\lambda}{\|\mathbf{v}_g\|_2}, 0\right)$$

This sets entire groups to zero (group sparsity) rather than individual elements:

def group_soft_threshold(v, lam, groups):
    """Proximal of lam * sum_g ||x_g||_2."""
    result = np.zeros_like(v)
    for g in groups:
        vg = v[g]
        norm_vg = np.linalg.norm(vg)
        if norm_vg > lam:
            result[g] = vg * (1 - lam / norm_vg)
    return result
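A small demonstration with two hand-picked groups; the entire second group is zeroed because its norm falls below the threshold:

v = np.array([3.0, 4.0, 0.3, 0.4])
groups = [[0, 1], [2, 3]]
print(group_soft_threshold(v, lam=1.0, groups=groups))
# [2.4 3.2 0.  0. ]: group 0 (norm 5) is shrunk by factor 0.8, group 1 (norm 0.5) is zeroed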

Definition: Total Variation (TV) Proximal Operator

The total variation of a 1D signal is $\mathrm{TV}(\mathbf{x}) = \sum_{i=1}^{n-1} |x_{i+1} - x_i|$, which promotes piecewise-constant solutions.

The proximal operator of TV has no simple closed form but can be computed in $O(n)$ time using the taut-string algorithm or via iterative methods:

def prox_tv_1d(v, lam, n_iter=100):
    """Proximal of lam * TV(x) via projected gradient on the dual (Chambolle)."""
    n = len(v)
    p = np.zeros(n - 1)  # dual variable, one entry per adjacent pair
    for _ in range(n_iter):
        # Primal estimate from the current dual: x = v - lam * D^T p
        x = v - lam * _adj_diff(p, n)
        # Gradient step on the dual using the forward difference of the primal;
        # step 1/(4*lam) is a safe choice since ||D D^T||_2 <= 4
        dx = np.diff(x)
        p = p + dx / (4 * lam)
        # Project the dual variable onto the box [-1, 1]
        p = np.clip(p, -1, 1)
    return v - lam * _adj_diff(p, n)

def _adj_diff(p, n):
    """Adjoint of forward difference: -div."""
    d = np.zeros(n)
    d[0] = -p[0]
    d[1:-1] = p[:-1] - p[1:]
    d[-1] = p[-1]
    return d
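To see the piecewise-constant effect, one can denoise a noisy step signal; the signal, noise level, and lam below are arbitrary test choices:

rng = np.random.default_rng(0)
clean = np.concatenate([np.zeros(50), np.ones(50)])
noisy = clean + 0.1 * rng.standard_normal(100)
denoised = prox_tv_1d(noisy, lam=0.5)
# denoised is nearly constant on each half while the jump at index 50 survives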

Theorem: ISTA (Iterative Shrinkage-Thresholding Algorithm)

For $\min_{\mathbf{x}} f(\mathbf{x}) + g(\mathbf{x})$ where $f$ is $L$-smooth (i.e., $\nabla f$ is $L$-Lipschitz) and $g$ is convex, the iteration:

$$\mathbf{x}_{k+1} = \mathrm{prox}_{\frac{1}{L} g}\!\left( \mathbf{x}_k - \frac{1}{L}\nabla f(\mathbf{x}_k)\right)$$

converges to the global minimum with rate $O(1/k)$: $f(\mathbf{x}_k) + g(\mathbf{x}_k) - f^\star - g^\star \le \frac{L\|\mathbf{x}_0 - \mathbf{x}^\star\|^2}{2k}$.

FISTA (Fast ISTA) adds Nesterov momentum to achieve $O(1/k^2)$: $\mathbf{y}_{k+1} = \mathbf{x}_k + \frac{k-1}{k+2}(\mathbf{x}_k - \mathbf{x}_{k-1})$, then the proximal step is applied to $\mathbf{y}_{k+1}$.

ISTA splits the problem: gradient descent handles the smooth part, the proximal operator handles the non-smooth part. Each iteration takes a gradient step on $f$, then "cleans up" with the proximal operator of $g$.

Example: ISTA for LASSO

Implement ISTA to solve the LASSO problem $\min_{\mathbf{x}} \frac{1}{2}\|\mathbf{A}\mathbf{x} - \mathbf{b}\|_2^2 + \lambda\|\mathbf{x}\|_1$.
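A minimal sketch of a solution, reusing soft_threshold from above; the problem sizes, sparsity pattern, and regularization heuristic below are arbitrary test choices, and a FISTA variant with the momentum rule from the theorem is included for comparison:

def ista_lasso(A, b, lam, n_iter=500):
    """ISTA for 0.5 * ||Ax - b||_2^2 + lam * ||x||_1."""
    L = np.linalg.norm(A.T @ A, 2)                 # Lipschitz constant of the gradient
    x = np.zeros(A.shape[1])
    for _ in range(n_iter):
        grad = A.T @ (A @ x - b)                   # gradient of the smooth part
        x = soft_threshold(x - grad / L, lam / L)  # prox of (lam/L) * ||.||_1
    return x

def fista_lasso(A, b, lam, n_iter=500):
    """FISTA: ISTA plus Nesterov momentum for the O(1/k^2) rate."""
    L = np.linalg.norm(A.T @ A, 2)
    x = y = np.zeros(A.shape[1])
    for k in range(1, n_iter + 1):
        x_new = soft_threshold(y - A.T @ (A @ y - b) / L, lam / L)
        y = x_new + (k - 1) / (k + 2) * (x_new - x)  # momentum step from the theorem
        x = x_new
    return x

# Sparse recovery test: 3 nonzeros, 50 measurements of a length-100 signal
rng = np.random.default_rng(1)
A = rng.standard_normal((50, 100))
x_true = np.zeros(100)
x_true[[5, 37, 80]] = [1.0, -2.0, 1.5]
b = A @ x_true + 0.01 * rng.standard_normal(50)
lam = 0.1 * np.max(np.abs(A.T @ b))   # common heuristic: a fraction of ||A^T b||_inf
x_hat = fista_lasso(A, b, lam)
print(np.flatnonzero(np.abs(x_hat) > 0.1))  # should recover a support close to [5, 37, 80]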

Proximal Operator Visualization

Visualize how soft-thresholding, projection, and group soft-thresholding transform input signals.


Proximal Operators and ISTA/FISTA

Complete implementations of soft-thresholding, projections, group soft-thresholding, the TV proximal operator, ISTA, and FISTA are in ch08/python/proximal_operators.py.

Quick Check

What does the soft-thresholding operator $\mathcal{S}_\lambda(v)$ do to components with $|v_i| < \lambda$?

Leaves them unchanged

Sets them to zero

Doubles them

Scales them by lambda

Common Mistake: Wrong Step Size in ISTA

Mistake:

Using step size $1/L$ where $L$ is guessed instead of computed. If $L$ is too small (so the step $1/L$ is too large), ISTA diverges.

Correction:

Compute $L = \|\mathbf{A}^T\mathbf{A}\|_2$ (the largest singular value of $\mathbf{A}$, squared) for least-squares problems: L = np.linalg.norm(A.T @ A, 2). Alternatively, use backtracking line search to adaptively find $L$, as sketched below.
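A minimal backtracking sketch, assuming generic callables f, grad_f, and prox(point, step); the helper name and the defaults L0 and eta are illustrative choices, not a fixed API:

def ista_step_backtracking(x, f, grad_f, prox, L0=1.0, eta=2.0):
    """One proximal-gradient step that adapts L by backtracking (illustrative)."""
    L, g, fx = L0, grad_f(x), f(x)
    while True:
        x_new = prox(x - g / L, 1.0 / L)
        d = x_new - x
        # Accept once f lies below its quadratic upper model at x_new
        if f(x_new) <= fx + g @ d + 0.5 * L * (d @ d):
            return x_new, L
        L *= eta  # step 1/L was too large: increase L and retry

In a full solver, the accepted L would typically be passed back as L0 for the next iteration, so backtracking only triggers when the local curvature grows.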

Key Takeaway

Proximal operators are the building blocks of modern optimization. Soft-thresholding gives LASSO, projection gives constrained optimization, group soft-thresholding gives group LASSO, and TV proximal gives piecewise-constant denoising. Memorize these four operators and you can implement most first-order optimization algorithms.

Why This Matters: Compressed Sensing in Wireless Channel Estimation

Wireless channels are often sparse in the delay-Doppler domain: only a few paths carry significant energy. ISTA and FISTA with soft-thresholding (the LASSO proximal) recover these sparse channels from far fewer pilot symbols than traditional least-squares estimation requires. This is the basis of compressed sensing channel estimation in 5G mmWave systems.

See full treatment in Chapter 15

Proximal Operator

The mapping $\mathrm{prox}_{\lambda g}(\mathbf{v}) = \arg\min_{\mathbf{x}} g(\mathbf{x}) + \frac{1}{2\lambda}\|\mathbf{x} - \mathbf{v}\|^2$, generalizing the gradient step to non-smooth functions.

Related: Soft-Thresholding

Soft-Thresholding

The element-wise operator $\mathcal{S}_\lambda(v_i) = \mathrm{sign}(v_i)\max(|v_i| - \lambda, 0)$, which is the proximal operator of the $\ell_1$ norm.

Related: Proximal Operator