Implementing Proximal Operators
Why Proximal Operators?
Many modern optimization problems have objectives that are the sum of a smooth term and a non-smooth regularizer: $F(x) = f(x) + g(x)$. Gradient descent cannot handle the non-smooth $g$ directly. Proximal operators generalize the gradient step to non-smooth functions, enabling algorithms like ISTA, FISTA, and ADMM.
In signal processing, proximal operators implement denoising, sparsity promotion, and constraint projection as simple closed-form operations.
Definition: Proximal Operator
The proximal operator of a function $g$ with step size $\lambda > 0$ is:

$$\operatorname{prox}_{\lambda g}(v) = \arg\min_x \left( g(x) + \frac{1}{2\lambda} \|x - v\|_2^2 \right)$$
Interpretation: find the point that balances minimizing $g$ against staying close to $v$. When $g = 0$, the proximal operator is the identity. When $g$ is the indicator function of a set $C$ ($g(x) = 0$ if $x \in C$, $+\infty$ otherwise), the proximal operator is the projection onto $C$.
The proximal operator always exists and is unique when $g$ is closed, proper, and convex.
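To make the definition concrete, here is a quick numerical check (a sketch; quad_prox and the illustrative values are ours, not from the chapter's code). For $g(x) = \tfrac{1}{2}\|x\|_2^2$, solving the defining minimization by hand gives the closed form $\operatorname{prox}_{\lambda g}(v) = v / (1 + \lambda)$, and a brute-force minimizer agrees:

import numpy as np
from scipy.optimize import minimize

def quad_prox(v, lam):
    """Closed-form prox of g(x) = 0.5 * ||x||_2^2: shrink toward the origin."""
    return v / (1 + lam)

v, lam = np.array([2.0, -1.0]), 0.5
# Numerically minimize g(x) + (1/(2*lam)) * ||x - v||^2
obj = lambda x: 0.5 * x @ x + (x - v) @ (x - v) / (2 * lam)
x_num = minimize(obj, np.zeros_like(v)).x
print(np.allclose(x_num, quad_prox(v, lam), atol=1e-5))  # True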
Definition: Soft-Thresholding Operator
The proximal operator of $g(x) = \lambda \|x\|_1$ is the soft-thresholding (shrinkage) operator, applied element-wise:

$$[\operatorname{prox}_{\lambda \|\cdot\|_1}(v)]_i = \operatorname{sign}(v_i)\, \max(|v_i| - \lambda,\ 0)$$
import numpy as np

def soft_threshold(v, lam):
    """Proximal operator of lam * ||x||_1."""
    return np.sign(v) * np.maximum(np.abs(v) - lam, 0)
This shrinks each component toward zero by $\lambda$. Components with $|v_i| \le \lambda$ are set exactly to zero, producing sparsity; this is the mechanism behind LASSO and compressed sensing.
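For instance (illustrative values):

v = np.array([3.0, -0.2, 0.7, -1.5, 0.05])
print(soft_threshold(v, 0.5))
# [ 2.5 -0.   0.2 -1.   0. ]: entries within 0.5 of zero are exactly zeroed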
Definition: Projection as Proximal Operator
The proximal operator of the indicator function of a closed convex set $C$ is the Euclidean projection:

$$\operatorname{prox}_{\iota_C}(v) = \Pi_C(v) = \arg\min_{x \in C} \|x - v\|_2$$
Common projections:
# Projection onto non-negative orthant
def proj_nonneg(v):
    return np.maximum(v, 0)

# Projection onto L2 ball of radius r
def proj_l2_ball(v, r):
    norm_v = np.linalg.norm(v)
    return v * min(1, r / norm_v) if norm_v > 0 else v

# Projection onto simplex {x >= 0, sum(x) = 1}
def proj_simplex(v):
    n = len(v)
    u = np.sort(v)[::-1]              # sort descending
    cssv = np.cumsum(u) - 1           # cumulative sums, shifted by the constraint
    # Largest index where the sorted entry still exceeds the running threshold
    rho = np.nonzero(u > cssv / np.arange(1, n + 1))[0][-1]
    theta = cssv[rho] / (rho + 1)     # optimal shift
    return np.maximum(v - theta, 0)
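Quick sanity checks with illustrative values:

v = np.array([0.5, -0.3, 1.2])
print(proj_nonneg(v))                        # [0.5 0.  1.2]
print(np.linalg.norm(proj_l2_ball(v, 1.0)))  # 1.0 (up to rounding): v lies outside the unit ball
w = proj_simplex(v)
print(w, w.sum())                            # entries are non-negative and sum to 1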
Definition: Group Soft-Thresholding
For the group LASSO penalty $\lambda \sum_{g} \|x_g\|_2$ (the $\ell_{2,1}$ mixed norm), the proximal operator applies block-wise shrinkage:

$$v_g \mapsto \max\!\left(1 - \frac{\lambda}{\|v_g\|_2},\ 0\right) v_g$$
This sets entire groups to zero (group sparsity) rather than individual elements:
def group_soft_threshold(v, lam, groups):
"""Proximal of lam * sum_g ||x_g||_2."""
result = np.zeros_like(v)
for g in groups:
vg = v[g]
norm_vg = np.linalg.norm(vg)
if norm_vg > lam:
result[g] = vg * (1 - lam / norm_vg)
return result
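Usage with two groups of coordinates (illustrative values):

v = np.array([0.3, 0.4, 3.0, 4.0])
groups = [np.array([0, 1]), np.array([2, 3])]
print(group_soft_threshold(v, 1.0, groups))
# [0.  0.  2.4 3.2]: the first group (norm 0.5 <= 1) is zeroed entirely,
# the second (norm 5) is shrunk by the factor 1 - 1/5 = 0.8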
Definition: Total Variation (TV) Proximal Operator
The total variation of a 1D signal is $\mathrm{TV}(x) = \sum_{i=1}^{n-1} |x_{i+1} - x_i|$, which promotes piecewise-constant solutions.
The proximal operator of TV has no simple closed form but can be computed in $O(n)$ time using the taut-string algorithm, or via iterative methods:
def prox_tv_1d(v, lam, n_iter=100):
    """Proximal of lam * TV(x) via the dual (Chambolle) iteration."""
    n = len(v)
    p = np.zeros(n - 1)  # dual variable, one entry per difference
    for _ in range(n_iter):
        # Primal estimate from the current dual variable
        x = v - lam * _adj_diff(p, n)
        # Gradient step on the dual: forward difference of the primal
        dx = np.diff(x)
        p = p + dx / (2 * lam)
        # Project the dual back onto [-1, 1]
        p = np.clip(p, -1, 1)
    return v - lam * _adj_diff(p, n)
def _adj_diff(p, n):
"""Adjoint of forward difference: -div."""
d = np.zeros(n)
d[0] = -p[0]
d[1:-1] = p[:-1] - p[1:]
d[-1] = p[-1]
return d
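A quick denoising demo (illustrative parameters):

rng = np.random.default_rng(0)
clean = np.concatenate([np.zeros(50), np.ones(50)])   # one-step signal
noisy = clean + 0.1 * rng.standard_normal(100)
denoised = prox_tv_1d(noisy, lam=1.0)
# denoised is close to piecewise-constant, with a single jump near index 50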
Theorem: ISTA (Iterative Shrinkage-Thresholding Algorithm)
For $F(x) = f(x) + g(x)$ where $f$ is $L$-smooth (i.e., $\nabla f$ is $L$-Lipschitz) and $g$ is convex, the iteration:

$$x_{k+1} = \operatorname{prox}_{g/L}\!\left( x_k - \tfrac{1}{L} \nabla f(x_k) \right)$$

converges to the global minimum with rate $O(1/k)$: $F(x_k) - F^\star \le \dfrac{L \|x_0 - x^\star\|_2^2}{2k}$.

FISTA (Fast ISTA) adds Nesterov momentum to achieve $O(1/k^2)$: compute $y_{k+1} = x_k + \dfrac{t_k - 1}{t_{k+1}} (x_k - x_{k-1})$ with $t_{k+1} = \dfrac{1 + \sqrt{1 + 4 t_k^2}}{2}$, then apply the proximal gradient step to $y_{k+1}$.
ISTA splits the problem: gradient descent handles the smooth part, the proximal operator handles the non-smooth part. Each iteration takes a gradient step on $f$, then "cleans up" with the proximal operator of $g$.
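The splitting is generic. A minimal sketch of a reusable proximal gradient routine (the helper name proximal_gradient and its signature are ours, not from the chapter's code):

def proximal_gradient(grad_f, prox_g, L, x0, n_iter=100):
    """Generic ISTA iteration: x <- prox_{g/L}(x - grad_f(x) / L)."""
    x = x0.copy()
    for _ in range(n_iter):
        x = prox_g(x - grad_f(x) / L, 1.0 / L)
    return x

# LASSO is one instantiation: f = 0.5*||Ax - b||^2, g = lam*||x||_1, so
# x_hat = proximal_gradient(lambda x: A.T @ (A @ x - b),
#                           lambda v, step: soft_threshold(v, lam * step),
#                           L, np.zeros(n))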
Example: ISTA for LASSO
Implement ISTA to solve the LASSO problem $\min_x \tfrac{1}{2} \|Ax - b\|_2^2 + \lambda \|x\|_1$.
Implementation
import numpy as np
def ista_lasso(A, b, lam, n_iter=500):
m, n = A.shape
L = np.linalg.norm(A.T @ A, 2) # Lipschitz constant
x = np.zeros(n)
objective = []
for k in range(n_iter):
grad = A.T @ (A @ x - b)
x = soft_threshold(x - grad / L, lam / L)
obj = 0.5 * np.linalg.norm(A @ x - b)**2 + lam * np.linalg.norm(x, 1)
objective.append(obj)
return x, objective
def soft_threshold(v, lam):
return np.sign(v) * np.maximum(np.abs(v) - lam, 0)
FISTA acceleration
def fista_lasso(A, b, lam, n_iter=500):
    m, n = A.shape
    L = np.linalg.norm(A.T @ A, 2)
    x = np.zeros(n)
    y = x.copy()   # extrapolated point
    t = 1.0        # momentum parameter t_k
    objective = []
    for k in range(n_iter):
        grad = A.T @ (A @ y - b)
        x_new = soft_threshold(y - grad / L, lam / L)
        t_new = (1 + np.sqrt(1 + 4 * t**2)) / 2
        y = x_new + (t - 1) / t_new * (x_new - x)  # Nesterov extrapolation
        x, t = x_new, t_new
        obj = 0.5 * np.linalg.norm(A @ x - b)**2 + lam * np.linalg.norm(x, 1)
        objective.append(obj)
    return x, objective
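A quick way to see the acceleration (a sketch with illustrative sizes; all names below are ours): run both solvers on a synthetic sparse recovery problem and compare objective traces.

rng = np.random.default_rng(0)
A = rng.standard_normal((100, 300))
x_true = np.zeros(300)
x_true[rng.choice(300, size=10, replace=False)] = rng.standard_normal(10)
b = A @ x_true + 0.01 * rng.standard_normal(100)

x_i, obj_i = ista_lasso(A, b, lam=0.1)
x_f, obj_f = fista_lasso(A, b, lam=0.1)
# At the same iteration count, obj_f typically sits well below obj_i,
# reflecting the O(1/k^2) vs O(1/k) rates.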
Proximal Operator Visualization
Visualize how soft-thresholding, projection, and group soft-thresholding transform input signals.
Interactive demo: Proximal Operators and ISTA/FISTA (code in ch08/python/proximal_operators.py).

Quick Check
What does the soft-thresholding operator do to components with $|v_i| \le \lambda$?
Leaves them unchanged
Sets them to zero
Doubles them
Scales them by lambda
Components smaller than the threshold are set exactly to zero, producing sparsity.
Common Mistake: Wrong Step Size in ISTA
Mistake:
Using step size $1/L$ where $L$ is guessed instead of computed. If $L$ is too small (step too large), ISTA diverges.
Correction:
Compute $L = \|A^\top A\|_2 = \sigma_{\max}(A)^2$ (the largest singular value of $A$, squared) for least-squares problems:

L = np.linalg.norm(A.T @ A, 2)
Alternatively, use backtracking line search to adaptively find $L$, as in the sketch below.
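A minimal backtracking sketch (our own helper, not from the chapter's code): start from an optimistic estimate and inflate L until the standard quadratic upper-bound condition holds for the proximal step.

def ista_step_backtracking(x, f, grad_f, lam, L=1.0, eta=2.0):
    """One ISTA step, doubling L until the sufficient-decrease test passes."""
    fx, gx = f(x), grad_f(x)
    while True:
        x_new = soft_threshold(x - gx / L, lam / L)
        d = x_new - x
        # Accept if f(x_new) <= f(x) + <grad f(x), d> + (L/2) * ||d||^2
        if f(x_new) <= fx + gx @ d + 0.5 * L * (d @ d):
            return x_new, L  # reuse the accepted L as next iteration's start
        L *= eta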
Key Takeaway
Proximal operators are the building blocks of modern optimization. Soft-thresholding gives LASSO, projection gives constrained optimization, group soft-thresholding gives group LASSO, and TV proximal gives piecewise-constant denoising. Memorize these four operators and you can implement most first-order optimization algorithms.
Why This Matters: Compressed Sensing in Wireless Channel Estimation
Wireless channels are often sparse in the delay-Doppler domain: only a few paths carry significant energy. ISTA and FISTA with soft-thresholding (the LASSO proximal) recover these sparse channels from far fewer pilot symbols than traditional least-squares estimation requires. This is the basis of compressed sensing channel estimation in 5G mmWave systems.
See full treatment in Chapter 15
Proximal Operator
The mapping $v \mapsto \arg\min_x \left( g(x) + \frac{1}{2\lambda} \|x - v\|_2^2 \right)$, generalizing the gradient step to non-smooth functions.
Related: Soft-Thresholding
Soft-Thresholding
The element-wise operator $v_i \mapsto \operatorname{sign}(v_i) \max(|v_i| - \lambda, 0)$, which is the proximal operator of the $\ell_1$ norm.
Related: Proximal Operator