Learned Message Passing and Deep Unfolding

From Hand-Designed to Learned Message Passing

AMP, OAMP, VAMP, and GAMP all share a common template: a linear step, a nonlinear denoiser, and an Onsager-style correction. The parameters of each stage – thresholds, damping factors, denoiser shape, even the linear operator itself – are derived from a statistical model (sparsity prior, noise level, matrix spectrum). When the assumed model matches reality, performance is Bayes-optimal; when it does not, mismatch erodes the gains.

Deep unfolding turns this limitation into a feature. Each iteration of the hand-designed algorithm is reinterpreted as a layer in a neural network, the per-layer parameters are declared free, and the whole unrolled network is trained end-to-end on representative (\mathbf{x}, \ntn{obs}) pairs. The result is an algorithm that retains the interpretability of message passing but adapts its coefficients to the empirical signal and matrix distribution – escaping the restrictive assumptions of analytical state evolution.

Definition:

LISTA – Learned ISTA

Given T unrolled layers, LISTA replaces the fixed ISTA iteration \mathbf{x}^{t+1} = \eta_{\text{st}}\!\left(\mathbf{x}^t + \tfrac{1}{L}\mathbf{A}^{\mathsf{H}}(\ntn{obs}-\mathbf{A}\mathbf{x}^t);\ \lambda\right) with the learned recursion

\mathbf{x}^{t+1} = \eta_{\text{st}}\!\left(\mathbf{W}_t \ntn{obs} + \mathbf{S}_t \mathbf{x}^t;\ \lambda_t\right), \qquad t=0,\dots,T-1,

where \{\mathbf{W}_t \in \mathbb{C}^{N\times M},\ \mathbf{S}_t \in \mathbb{C}^{N\times N},\ \lambda_t \in \mathbb{R}_+\} are learnable parameters, typically initialized from the ISTA values \mathbf{W}_t = \mathbf{A}^{\mathsf{H}}/L, \mathbf{S}_t = \mathbf{I} - \mathbf{A}^{\mathsf{H}}\mathbf{A}/L, with L = \|\mathbf{A}\|_2^2 the Lipschitz constant of the gradient step. Training minimizes the reconstruction loss \mathbb{E}\|\mathbf{x}^{T} - \mathbf{x}\|^2 over a dataset of signal–measurement pairs.
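To make the recursion concrete, here is a minimal LISTA sketch in PyTorch for the real-valued case; the class name, default layer count, and threshold initialization are illustrative choices rather than a reference implementation.

```python
import torch
import torch.nn as nn

def soft_threshold(u, lam):
    # Elementwise soft-thresholding eta_st(u; lambda).
    return torch.sign(u) * torch.clamp(u.abs() - lam, min=0.0)

class LISTA(nn.Module):
    def __init__(self, A, T=16):
        super().__init__()
        M, N = A.shape
        L = torch.linalg.matrix_norm(A, ord=2) ** 2                    # Lipschitz constant ||A||_2^2
        self.W = nn.ParameterList([nn.Parameter(A.t() / L) for _ in range(T)])                     # init A^T / L
        self.S = nn.ParameterList([nn.Parameter(torch.eye(N) - A.t() @ A / L) for _ in range(T)])  # init I - A^T A / L
        self.lam = nn.ParameterList([nn.Parameter(torch.tensor(0.1)) for _ in range(T)])           # per-layer threshold (placeholder init)
        self.T, self.N = T, N

    def forward(self, y):
        # y: (batch, M) measurements; returns x_hat: (batch, N).
        x = torch.zeros(y.shape[0], self.N, device=y.device, dtype=y.dtype)
        for t in range(self.T):
            x = soft_threshold(y @ self.W[t].t() + x @ self.S[t].t(), self.lam[t])
        return x
```

Training then reduces to minimizing torch.mean((model(y) - x) ** 2) over mini-batches of signal–measurement pairs with any standard optimizer.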

Empirically, LISTA reaches ISTA's T=1000-iteration MSE in 10–20 learned layers – a 50× speed-up. Weight tying (\mathbf{W}_t\equiv\mathbf{W}, \mathbf{S}_t\equiv\mathbf{S}) gives a lighter model that still beats vanilla ISTA.

Definition:

LAMP – Learned AMP

LAMP unrolls the AMP iteration. Each layer reads (\mathbf{x}^t,\mathbf{r}^t) and produces

\begin{aligned} \mathbf{x}^{t+1} &= \eta_t\!\left(\mathbf{B}_t\mathbf{r}^t + \mathbf{x}^t;\ \boldsymbol{\theta}_t\right), \\ \mathbf{r}^{t+1} &= \ntn{obs} - \mathbf{A}\mathbf{x}^{t+1} + b_t \mathbf{r}^t, \end{aligned}

where \mathbf{B}_t is a learned feedback matrix (initialized at \mathbf{A}^{\mathsf{H}}), \eta_t is a parameterized denoiser (soft-threshold, scaled soft-threshold, or a small MLP), and b_t \in \mathbb{R} is the learned Onsager coefficient. All \{\mathbf{B}_t,\boldsymbol{\theta}_t,b_t\}_{t=0}^{T-1} are trained jointly by back-propagation.

LAMP preserves the Onsager-style correction b_t\mathbf{r}^t – but instead of computing it analytically from \delta^{-1}\langle\eta'\rangle, it learns the scalar b_t directly. This is what makes LAMP robust to mismatched matrix ensembles: the network finds the right correction for the actual operator at hand.
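Under the same real-valued assumptions, one LAMP layer might be organized as below; the soft-threshold denoiser and the initial values of theta and b are placeholders (in practice b_t is initialized from the analytical Onsager value \delta^{-1}\langle\eta'\rangle, as the engineering note below recommends).

```python
import torch
import torch.nn as nn

class LAMPLayer(nn.Module):
    def __init__(self, A):
        super().__init__()
        self.register_buffer("A", A)                   # fixed sensing operator
        self.B = nn.Parameter(A.t().clone())           # learned feedback matrix, init A^T
        self.theta = nn.Parameter(torch.tensor(1.0))   # denoiser threshold (placeholder init)
        self.b = nn.Parameter(torch.tensor(0.5))       # Onsager scalar (init from the analytical value in practice)

    def forward(self, x, r, y):
        # x: (batch, N), r: (batch, M), y: (batch, M).
        u = r @ self.B.t() + x                                                # B_t r^t + x^t
        x_new = torch.sign(u) * torch.clamp(u.abs() - self.theta, min=0.0)    # eta_t(.; theta_t)
        r_new = y - x_new @ self.A.t() + self.b * r                           # residual with learned Onsager term b_t r^t
        return x_new, r_new
```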

Definition:

LDVAMP – Learned Denoising VAMP

LDVAMP unrolls the VAMP recursion, replacing each scalar state-evolution update with a learned function and each denoiser with a neural network. Per layer:

\begin{aligned} \hat{\mathbf{x}}_1^t &= \mathbf{g}_1(\mathbf{r}_1^t;\ \boldsymbol{\phi}_t^{(1)}), \quad \alpha_1^t = \text{learned divergence}, \\ (\mathbf{r}_2^t,\gamma_2^t) &= \text{Onsager}(\hat{\mathbf{x}}_1^t,\mathbf{r}_1^t,\alpha_1^t), \\ \hat{\mathbf{x}}_2^t &= \mathbf{g}_2(\mathbf{r}_2^t;\ \boldsymbol{\phi}_t^{(2)}), \quad \alpha_2^t = \text{learned divergence}, \\ (\mathbf{r}_1^{t+1},\gamma_1^{t+1}) &= \text{Onsager}(\hat{\mathbf{x}}_2^t,\mathbf{r}_2^t,\alpha_2^t). \end{aligned}

Here \mathbf{g}_2 is the LMMSE step (its parameters \gamma_2 learned rather than matched to \mathbf{A}) and \mathbf{g}_1 is a learned prior denoiser (a CNN for images, a parameterized shrinkage for sparse signals).

LDVAMP inherits VAMP's robustness to ill-conditioned matrices while dropping the need to specify the signal prior or matrix spectrum analytically. It is the state-of-the-art unrolled network for structured compressed sensing.
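A stripped-down LDVAMP layer could look as follows; the tiny elementwise MLP standing in for \mathbf{g}_1, the scalar parameterization of the divergences \alpha_1, \alpha_2, and the single learned precision \gamma_2 are simplifying assumptions (a full implementation also propagates the precisions \gamma_1^t, \gamma_2^t between half-steps).

```python
import torch
import torch.nn as nn

class LDVAMPLayer(nn.Module):
    def __init__(self, A, hidden=32):
        super().__init__()
        self.register_buffer("A", A)
        # g1: learned prior denoiser, applied elementwise through a small MLP.
        self.g1 = nn.Sequential(nn.Linear(1, hidden), nn.ReLU(), nn.Linear(hidden, 1))
        self.alpha1 = nn.Parameter(torch.tensor(0.5))   # learned divergence of g1
        self.alpha2 = nn.Parameter(torch.tensor(0.5))   # learned divergence of g2
        self.gamma2 = nn.Parameter(torch.tensor(1.0))   # learned LMMSE regularization

    def forward(self, r1, y):
        # r1: (batch, N), y: (batch, M).
        x1 = self.g1(r1.unsqueeze(-1)).squeeze(-1)                      # prior denoising g1(r1)
        r2 = (x1 - self.alpha1 * r1) / (1.0 - self.alpha1)              # Onsager (extrinsic) step
        A, N = self.A, self.A.shape[1]
        G = torch.linalg.inv(A.t() @ A + self.gamma2 * torch.eye(N))    # g2: LMMSE with learned gamma2
        x2 = (y @ A + self.gamma2 * r2) @ G
        r1_next = (x2 - self.alpha2 * r2) / (1.0 - self.alpha2)         # Onsager step back to g1
        return x2, r1_next
```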

Theorem: Linear Convergence Rate of LISTA

Assume \mathbf{A} satisfies a restricted isometry property with constant \delta_{2s} < 1/3 and the signal is s-sparse. Then the optimal LISTA parameters \{\mathbf{W}_t^\star,\mathbf{S}_t^\star,\lambda_t^\star\} achieve the linear convergence rate

\|\mathbf{x}^{T} - \mathbf{x}\|_2 \le c_1 \cdot q^{T}\ \|\mathbf{x}\|_2 + c_2\ \ntn{noisestd},

with contraction factor q = q(\delta_{2s}) \in (0,1) strictly smaller than the ISTA contraction q_{\text{ISTA}} at the same regularization.

LISTA beats ISTA's linear rate by adapting each layer's operator to the current sparsity pattern. Early layers can be aggressive (large thresholds) to commit to the strongest components; later layers can be gentle to refine small components. ISTA uses the same operator for every iteration, sacrificing this adaptivity.
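As a back-of-the-envelope illustration of the bound, take hypothetical contraction factors q = 0.5 for LISTA and q_{\text{ISTA}} = 0.9: ignoring the constants and the noise floor, the transient term q^T drops below 10^{-4} after about 14 layers versus about 88 iterations.

```python
import math

def layers_needed(q, tol=1e-4):
    # Smallest T with q**T <= tol (constants and the noise floor ignored).
    return math.ceil(math.log(tol) / math.log(q))

print(layers_needed(0.5))   # 14  (hypothetical learned contraction)
print(layers_needed(0.9))   # 88  (hypothetical ISTA contraction)
```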

Example: LAMP for Bernoulli-Gaussian Recovery

Design a 10-layer LAMP network for recovering Bernoulli-Gaussian signals (\rho = 0.1, unit-variance active components) from measurements \ntn{obs} = \mathbf{A}\mathbf{x} + \mathbf{w} with \mathbf{A} \in \mathbb{R}^{M\times N}, M/N = 0.5, N = 500, SNR = 20 dB. Describe the trainable parameters, loss, and expected behaviour.
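One way to generate training data for this example is sketched below, assuming an i.i.d. Gaussian sensing matrix and using the stated parameters (\rho = 0.1, M/N = 0.5, N = 500, 20 dB SNR); the batch size and seed are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
N, M, rho, snr_db = 500, 250, 0.1, 20.0
A = rng.standard_normal((M, N)) / np.sqrt(M)     # i.i.d. Gaussian sensing matrix (assumed)

def sample_batch(batch_size):
    # Bernoulli-Gaussian signals: active with probability rho, unit variance when active.
    support = rng.random((batch_size, N)) < rho
    x = support * rng.standard_normal((batch_size, N))
    y_clean = x @ A.T
    # Scale the noise to the requested SNR using the empirical signal power.
    sigma = np.sqrt(np.mean(y_clean ** 2) / 10 ** (snr_db / 10))
    y = y_clean + sigma * rng.standard_normal(y_clean.shape)
    return x, y

x, y = sample_batch(64)
```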

⚠️Engineering Note

Training Tips for Unrolled Networks

Training unrolled message-passing networks is straightforward in principle but has a handful of recurring pitfalls:

  • Layer-wise greedy warm-start. Train layer 0, freeze it, then add and train layer 1, and so on (a sketch follows this list). This avoids the vanishing-gradient problem that plagues end-to-end training of deep unrolled networks.
  • Tied vs. untied weights. Weight tying (\mathbf{B}_t \equiv \mathbf{B}) cuts the parameter count by a factor of T and generalizes better when data is scarce; untied weights win on large datasets.
  • Matrix-specific vs. matrix-agnostic training. If \mathbf{A} is fixed (e.g., a physical MRI encoder), train on that single matrix. If \mathbf{A} varies per sample (e.g., random masks), train over the ensemble. The two regimes give different learned parameters.
  • Loss curriculum. Start with a soft loss (per-layer MSE averaged over t) and anneal to the final-layer loss; this stabilizes early training.
  • Initialisation from analytical parameters. Always initialize \mathbf{B}_0 to \mathbf{A}^{\mathsf{H}} and b_0 to the analytical Onsager coefficient. Random initialization often fails to converge even after training.
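A minimal sketch of the layer-wise greedy schedule from the first bullet, assuming per-layer modules with the (x, r, y) interface used in the LAMP sketch above; the optimizer, step count, and learning rate are placeholders.

```python
import torch

def train_greedy(layers, data_loader, steps_per_layer=1000, lr=1e-3):
    # Grow the unrolled network one layer at a time, optimizing only the newest layer.
    trained = torch.nn.ModuleList()
    for layer in layers:
        trained.append(layer)
        opt = torch.optim.Adam(layer.parameters(), lr=lr)     # earlier layers are not updated
        for (x_true, y), _ in zip(data_loader, range(steps_per_layer)):
            x = torch.zeros_like(x_true)                      # x^0 = 0
            r = y.clone()                                     # r^0 = y
            for l in trained:
                x, r = l(x, r, y)
            loss = torch.mean((x - x_true) ** 2)              # final-layer MSE
            opt.zero_grad()
            loss.backward()
            opt.step()
    return trained
```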

LISTA vs ISTA Convergence

Compare reconstruction MSE as a function of the number of (un)rolled iterations for ISTA with the optimal fixed step size and LISTA with learned per-layer parameters. The learned network reaches ISTA's asymptotic MSE in a small fraction of the layers.


LAMP: MSE vs Layer Count

Visualize the final MSE achieved by LAMP as the number of unrolled layers grows. Compare against fixed-parameter AMP at the same iteration count. Notice how LAMP saturates faster and to a lower floor, especially for structured sensing matrices where AMP struggles.


Common Mistake: Overfitting to a Single Sensing Matrix

Mistake:

Training an unrolled LAMP/LDVAMP network with a single realization of \mathbf{A} drawn from the distribution of interest, and then deploying it on different realizations from the same distribution. The learned weights encode the idiosyncrasies of the training matrix and collapse on novel ones.

Correction:

Decide upfront whether the sensing matrix is fixed (e.g., a calibrated imaging system, a trained sparse code) or random per sample (e.g., random masks, fresh pilot realizations). In the fixed case, training with the single matrix is correct. In the random case, resample \mathbf{A} every mini-batch during training so that the learned parameters generalize over the ensemble. Mismatch between training and deployment is a leading cause of disappointing unrolled-network results in practice.
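A sketch of ensemble training for the random-matrix regime, with a fresh \mathbf{A} drawn for every mini-batch; the dimensions, activity level, and noise level are illustrative, and the actual network update is elided.

```python
import numpy as np

rng = np.random.default_rng(1)
N, M, rho, batch = 500, 250, 0.1, 64
for step in range(1000):
    # Resample the sensing matrix every mini-batch so the network sees the ensemble.
    A = rng.standard_normal((M, N)) / np.sqrt(M)
    x = (rng.random((batch, N)) < rho) * rng.standard_normal((batch, N))
    y = x @ A.T + 0.01 * rng.standard_normal((batch, M))
    # ... forward pass with (y, A), loss against x, optimizer step ...
```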

LISTA

Learned ISTA – an unrolled neural network with the ISTA iteration as its layer template. Parameters (\mathbf{W}_t,\mathbf{S}_t,\lambda_t) are trained end-to-end to minimize reconstruction MSE. Achieves ISTA's asymptotic accuracy in far fewer layers than iterations.

Related: LAMP, Deep unfolding (algorithm unrolling)

LAMP

Learned AMP – an unrolled AMP iteration with learnable feedback matrix \mathbf{B}_t, denoiser parameters, and Onsager scalar b_t. Retains AMP's interpretability while adapting to empirical signal and matrix distributions.

Related: LISTA, Deep unfolding (algorithm unrolling)

Deep unfolding (algorithm unrolling)

A design paradigm that converts an iterative algorithm into a deep neural network by (i) unrolling a fixed number of iterations into layers, (ii) declaring per-layer parameters as trainable, and (iii) fitting them by end-to-end back-propagation. Combines the inductive bias of classical algorithms with the adaptivity of learned models.

Related: LISTA, LAMP

Quick Check

What is the principal advantage of LISTA over ISTA when both are run for T iterations / layers?

LISTA has a strictly convex loss function, while ISTA does not.

LISTA learns per-layer parameters by end-to-end training, so it reaches low MSE in far fewer layers.

LISTA does not require a sparsity assumption, while ISTA does.

LISTA provably recovers the exact LASSO solution, while ISTA only approximates it.

Why This Matters: Unrolled VAMP for Wireless Channel Estimation

Pilot-based channel estimation in OFDM and massive-MIMO uplinks often reduces to a structured compressed-sensing problem: a sparse delay-Doppler-angular channel observed through a partial DFT or Kronecker dictionary. LDVAMP is a natural fit – the LMMSE step uses the known dictionary, while the learned prior denoiser captures dataset-specific channel statistics (clustering of multipath components, angular selectivity, Doppler coherence) that an analytical prior would miss.

The CommIT group has explored unrolled-VAMP pipelines for RF imaging and joint channel-activity estimation in unsourced random access, where the mix of known structure (sensing operator) and unknown data-driven priors (channel clusters) is exactly where unrolled networks outperform both hand-designed message passing and generic deep learning.

See full treatment in Chapter 27, Section sec-lista-imaging.