SURE-Based Training

Estimating MSE Without Ground Truth

Both DIP and Noise2Noise address the absence of clean targets, but from different angles: DIP avoids training altogether; Noise2Noise requires paired measurements. Stein's Unbiased Risk Estimate (SURE) offers a third path: it provides an unbiased estimate of the MSE $\mathbb{E}[\|f_\theta(\mathbf{y}) - \mathbf{x}\|^2]$ using only the noisy observation $\mathbf{y}$ and the denoiser's divergence.

The price of not having clean targets is a single extra term, the divergence $\operatorname{div}(f_\theta)$, which can be computed efficiently via a single vector-Jacobian product.

Definition: Stein's Unbiased Risk Estimate (SURE)

SURE provides an unbiased estimate of the MSE of a denoiser $f_\theta(\mathbf{y})$ without access to the clean signal:

$$\text{SURE}(f_\theta) = \frac{1}{N}\|\mathbf{y} - f_\theta(\mathbf{y})\|^2 - \sigma^2 + \frac{2\sigma^2}{N}\operatorname{div}(f_\theta)(\mathbf{y})$$

where $\operatorname{div}(f_\theta) = \sum_{i=1}^N \frac{\partial [f_\theta(\mathbf{y})]_i}{\partial y_i}$ is the divergence of the denoiser and $\sigma^2$ is the noise variance.

Unbiasedness property: $\mathbb{E}[\text{SURE}(f_\theta)] = \frac{1}{N}\mathbb{E}[\|f_\theta(\mathbf{y}) - \mathbf{x}\|^2]$ for $\mathbf{y} = \mathbf{x} + \mathbf{w}$ with $\mathbf{w} \sim \mathcal{N}(\mathbf{0}, \sigma^2\mathbf{I})$.

SURE turns the unsupervised denoising problem into one that can be trained like a supervised one: the SURE loss can be minimised via gradient descent, and in expectation its minimiser is the MMSE denoiser. The divergence term measures how strongly the denoiser's output follows fluctuations in its input; it is the price of not having clean targets.


Theorem: SURE Is an Unbiased Estimate of MSE

For $\mathbf{y} = \mathbf{x} + \mathbf{w}$ with $\mathbf{w} \sim \mathcal{N}(\mathbf{0}, \sigma^2\mathbf{I})$ and a weakly differentiable denoiser $f_\theta$:

$$\mathbb{E}_\mathbf{w}\bigl[\text{SURE}(f_\theta)\bigr] = \frac{1}{N}\mathbb{E}_\mathbf{w}\bigl[\|f_\theta(\mathbf{y}) - \mathbf{x}\|^2\bigr].$$

The SURE identity is a consequence of Stein's lemma: $\mathbb{E}[w_i \cdot g(\mathbf{w})] = \sigma^2\,\mathbb{E}[\partial g / \partial w_i]$ for $w_i \sim \mathcal{N}(0, \sigma^2)$. This connects the cross-term $\langle f_\theta(\mathbf{y}) - \mathbf{y}, \mathbf{x} - \mathbf{y}\rangle$ (which depends on the unknown $\mathbf{x}$) to the divergence (which depends only on the observable $\mathbf{y}$).
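
To make the step explicit, here is a short derivation sketch (an expansion of the argument above, not a full proof). Writing $\mathbf{w} = \mathbf{y} - \mathbf{x}$,

$$\|f_\theta(\mathbf{y}) - \mathbf{x}\|^2 = \|f_\theta(\mathbf{y}) - \mathbf{y}\|^2 + \|\mathbf{w}\|^2 + 2\langle f_\theta(\mathbf{y}) - \mathbf{y}, \mathbf{w}\rangle.$$

Taking expectations, $\mathbb{E}\|\mathbf{w}\|^2 = N\sigma^2$ and $\mathbb{E}\langle \mathbf{y}, \mathbf{w}\rangle = \mathbb{E}\langle \mathbf{x} + \mathbf{w}, \mathbf{w}\rangle = N\sigma^2$, while Stein's lemma applied componentwise gives $\mathbb{E}\langle f_\theta(\mathbf{y}), \mathbf{w}\rangle = \sigma^2\,\mathbb{E}[\operatorname{div}(f_\theta)(\mathbf{y})]$. Combining the three terms,

$$\mathbb{E}\|f_\theta(\mathbf{y}) - \mathbf{x}\|^2 = \mathbb{E}\|\mathbf{y} - f_\theta(\mathbf{y})\|^2 - N\sigma^2 + 2\sigma^2\,\mathbb{E}[\operatorname{div}(f_\theta)(\mathbf{y})],$$

and dividing by $N$ yields the stated identity.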

Definition: Monte Carlo Divergence Estimation

Computing $\operatorname{div}(f_\theta) = \sum_i \partial [f_\theta(\mathbf{y})]_i / \partial y_i$ exactly requires $N$ backpropagation passes (one per pixel), which is prohibitively expensive. The Monte Carlo estimator uses a single random probe vector:

$$\widehat{\operatorname{div}}(f_\theta) = \mathbf{b}^\top \mathbf{J}_{f_\theta} \mathbf{b}$$

where $\mathbf{b} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}_N)$ and $\mathbf{b}^\top \mathbf{J}_{f_\theta}$ is computed via a single vector-Jacobian product (one backward pass), after which the estimate is its inner product with $\mathbf{b}$.

Unbiasedness: $\mathbb{E}_\mathbf{b}[\widehat{\operatorname{div}}] = \operatorname{div}(f_\theta)$.

The MC divergence adds exactly one extra backward pass per training sample. The variance can be reduced by averaging over multiple probe vectors, but in practice a single probe suffices.
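
In an autodiff framework the estimator is only a few lines. The following is a minimal PyTorch sketch (the helper name `mc_divergence` and its batch-first interface are my own choices); it uses `torch.autograd.grad` to form exactly the vector-Jacobian product described above.

```python
import torch

def mc_divergence(output, y):
    """Single-probe Hutchinson estimate of div f(y) = tr(J_f), one value per sample.

    `output` = f(y) must have been computed from `y` with y.requires_grad_(True);
    `y` has shape [B, ...] with the batch dimension first.
    """
    b = torch.randn_like(y)                     # Gaussian probe vector
    # b^T J_f via one vector-Jacobian product (a single backward pass);
    # create_graph=True keeps the estimate differentiable w.r.t. the network weights
    vjp = torch.autograd.grad(output, y, grad_outputs=b, create_graph=True)[0]
    return (vjp * b).flatten(1).sum(dim=1)      # b^T J_f b, per sample
```

Averaging the returned value over several probe vectors reduces the variance at the cost of extra backward passes.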

SURE vs. True MSE During Training

Compare the SURE loss and the true MSE (computed with ground truth) during training. For a linear denoiser $f(\mathbf{y}) = \alpha\mathbf{y}$, SURE is exact (no estimation variance). For nonlinear denoisers (soft thresholding, neural network), SURE tracks the true MSE with increasing variance.

Observe that the SURE minimum coincides with the MSE minimum, confirming unbiasedness. The divergence term increases with denoiser complexity (neural net > soft threshold > linear).
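
The same agreement is easy to check numerically. The sketch below is an illustrative NumPy experiment (the elementwise tanh "denoiser", signal size, and trial count are arbitrary choices of mine): it averages SURE and the true MSE over many noise realisations for a fixed clean signal and shows that the two averages coincide.

```python
import numpy as np

rng = np.random.default_rng(0)
N, sigma, trials = 256, 0.5, 20_000
x = rng.standard_normal(N)                      # fixed clean signal

sure_vals, mse_vals = [], []
for _ in range(trials):
    y = x + sigma * rng.standard_normal(N)      # y = x + w,  w ~ N(0, sigma^2 I)
    f = np.tanh(y)                              # toy denoiser f(y) = tanh(y)
    div = np.sum(1.0 - np.tanh(y) ** 2)         # exact divergence of elementwise tanh
    sure_vals.append(np.mean((y - f) ** 2) - sigma**2 + 2 * sigma**2 / N * div)
    mse_vals.append(np.mean((f - x) ** 2))

print(np.mean(sure_vals), np.mean(mse_vals))    # the two averages agree closely
```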


Example: SURE for Linear Denoisers

Compute SURE in closed form for the linear denoiser $f_\alpha(\mathbf{y}) = \alpha\,\mathbf{y}$ and find the optimal shrinkage $\alpha^*$.
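
One way to carry out the calculation (a sketch of my own, using the definitions above): for $f_\alpha(\mathbf{y}) = \alpha\mathbf{y}$ the Jacobian is $\alpha\mathbf{I}$, so $\operatorname{div}(f_\alpha) = \alpha N$ and

$$\text{SURE}(\alpha) = \frac{(1-\alpha)^2}{N}\|\mathbf{y}\|^2 - \sigma^2 + 2\sigma^2\alpha.$$

Setting the derivative with respect to $\alpha$ to zero gives $-\tfrac{2(1-\alpha)}{N}\|\mathbf{y}\|^2 + 2\sigma^2 = 0$, hence

$$\alpha^* = 1 - \frac{N\sigma^2}{\|\mathbf{y}\|^2},$$

a Wiener-like shrinkage that attenuates more aggressively as the noise level grows relative to the signal energy.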

Theorem: Generalised SURE for Inverse Problems (GSURE)

For the inverse problem $\mathbf{y} = \mathbf{A}\mathbf{x} + \mathbf{w}$ with $\mathbf{w} \sim \mathcal{N}(\mathbf{0}, \sigma^2\mathbf{I}_M)$, the Generalised SURE for a reconstruction network $f_\theta \colon \mathbb{R}^M \to \mathbb{R}^N$ is:

$$\text{GSURE}(f_\theta) = \frac{1}{M}\|\mathbf{y} - \mathbf{A}f_\theta(\mathbf{y})\|^2 - \sigma^2 + \frac{2\sigma^2}{M}\operatorname{div}_\mathbf{y}(\mathbf{A}f_\theta)$$

where $\operatorname{div}_\mathbf{y}(\mathbf{A}f_\theta) = \operatorname{tr}(\mathbf{A}\,\mathbf{J}_{f_\theta})$.

GSURE is unbiased for the projected MSE $\frac{1}{M}\mathbb{E}[\|\mathbf{A}f_\theta(\mathbf{y}) - \mathbf{A}\mathbf{x}\|^2]$, not the full reconstruction MSE.

GSURE constrains only the component of the reconstruction lying in the range of $\mathbf{A}^H$; it says nothing about the null-space component. An additional regulariser (TV, DIP, equivariance) is needed for the null space.
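
The trace term admits the same single-probe Monte Carlo estimate, now probing in measurement space: $\operatorname{tr}(\mathbf{A}\mathbf{J}_{f_\theta}) = \mathbb{E}_\mathbf{b}[\mathbf{b}^\top \mathbf{A}\mathbf{J}_{f_\theta}\mathbf{b}]$ with $\mathbf{b} \in \mathbb{R}^M$. The PyTorch sketch below is an illustration under the simplifying assumption that $\mathbf{A}$ is available as a dense matrix (in practice it is usually a linear operator with a matching adjoint); the function name `gsure_loss` is my own.

```python
import torch

def gsure_loss(model, y, A, sigma):
    """Monte Carlo GSURE for y = A x + w, w ~ N(0, sigma^2 I_M).

    y: [B, M] measurements, A: [M, N] dense forward matrix, sigma: noise std.
    """
    M = y.shape[1]
    y = y.clone().requires_grad_(True)
    x_hat = model(y)                               # reconstruction, shape [B, N]
    Ax_hat = x_hat @ A.T                           # A f(y), back in measurement space

    residual = ((y - Ax_hat) ** 2).mean(dim=1)     # (1/M) ||y - A f(y)||^2 per sample

    b = torch.randn_like(y)                        # probe in measurement space
    # b^T (A J_f) via one backward pass through the composition A(f(.))
    vjp = torch.autograd.grad(Ax_hat, y, grad_outputs=b, create_graph=True)[0]
    trace_est = (vjp * b).sum(dim=1)               # ~ tr(A J_f), per sample

    return (residual - sigma**2 + (2 * sigma**2 / M) * trace_est).mean()
```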


SURE-Based Denoiser Training

Complexity: Each training step costs $\sim 2\times$ a standard supervised step (one extra backward pass for the divergence).
Input: Noisy images $\{\mathbf{y}_j\}_{j=1}^J$, noise variance $\sigma^2$, denoiser network $f_\theta$
Output: Trained denoiser $f_{\theta^*}$
1. Initialise $\theta$ randomly
2. for epoch $= 1, \ldots, E$ do
3. $\quad$ for mini-batch $\{\mathbf{y}_j\}$ do
4. $\quad\quad$ Forward pass: $\hat{\mathbf{x}}_j = f_\theta(\mathbf{y}_j)$
5. $\quad\quad$ Residual: $r_j = \frac{1}{N}\|\mathbf{y}_j - \hat{\mathbf{x}}_j\|^2$
6. $\quad\quad$ MC divergence: sample $\mathbf{b} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$, compute $\hat{d}_j = \mathbf{b}^\top \mathbf{J}_{f_\theta}\mathbf{b}$ via a vector-Jacobian product
7. $\quad\quad$ SURE loss: $\mathcal{L}_j = r_j - \sigma^2 + \frac{2\sigma^2}{N}\hat{d}_j$
8. $\quad\quad$ Update: $\theta \leftarrow \theta - \eta\,\nabla_\theta\,\frac{1}{|\text{batch}|}\sum_j \mathcal{L}_j$
9. $\quad$ end for
10. end for
11. return $f_{\theta^*}$
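
A corresponding PyTorch training step might look like the following minimal sketch (one of several possible implementations; the batch layout, optimiser handling, and function name are placeholder choices of mine). It reuses the `mc_divergence` helper sketched earlier for step 6.

```python
import torch

def sure_training_step(model, optimizer, y_batch, sigma):
    """One optimisation step on the Monte Carlo SURE loss (Gaussian noise, known sigma)."""
    n = y_batch[0].numel()                                     # N: pixels per image
    y = y_batch.clone().requires_grad_(True)

    x_hat = model(y)                                           # step 4: forward pass
    residual = ((y - x_hat) ** 2).flatten(1).mean(dim=1)       # step 5: (1/N)||y - x_hat||^2
    d_hat = mc_divergence(x_hat, y)                            # step 6: b^T J b via one VJP
    loss = (residual - sigma**2 + (2 * sigma**2 / n) * d_hat).mean()   # step 7: SURE

    optimizer.zero_grad()
    loss.backward()                                            # step 8: gradient step on theta
    optimizer.step()
    return loss.item()
```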

SURE training requires the noise variance σ2\sigma^2 to be known. If unknown, σ2\sigma^2 can be estimated from the measurements (e.g., median absolute deviation of wavelet coefficients).
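
For example, one common robust estimate (a sketch under the assumption that the clean image is smooth enough that the finest-scale wavelet detail coefficients are dominated by noise) uses the median absolute deviation of the diagonal Haar detail subband:

```python
import numpy as np

def estimate_sigma_mad(img):
    """Robust noise-std estimate: MAD of the finest-scale Haar HH coefficients / 0.6745."""
    img = np.asarray(img, dtype=np.float64)
    img = img[: img.shape[0] // 2 * 2, : img.shape[1] // 2 * 2]   # crop to even size
    a = img[0::2, 0::2]; b = img[0::2, 1::2]
    c = img[1::2, 0::2]; d = img[1::2, 1::2]
    hh = (a - b - c + d) / 2.0             # orthonormal Haar diagonal (HH) coefficients
    return np.median(np.abs(hh)) / 0.6745  # MAD-to-sigma conversion for Gaussian noise
```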


Common Mistake: SURE Requires Gaussian Noise with Known Variance

Mistake:

Applying SURE-based training to RF imaging data with non-Gaussian noise (e.g., speckle, Poisson photon noise) or unknown noise level.

Correction:

Standard SURE assumes $\mathbf{w} \sim \mathcal{N}(\mathbf{0}, \sigma^2\mathbf{I})$ with known $\sigma^2$. Violations cause biased risk estimates:

  • Non-Gaussian noise: Use Poisson-SURE or exponential-family SURE extensions (Eldar, 2008).
  • Unknown $\sigma^2$: Estimate from data using robust methods (MAD of wavelet coefficients, or from measurement residuals).
  • Correlated noise: Use the generalised form with known covariance $\boldsymbol{\Sigma}_w$.

For RF imaging, thermal noise is Gaussian but the effective noise after beamforming/matched filtering may be coloured.

🔧 Engineering Note

Computational Cost of MC Divergence

The MC divergence estimator $\widehat{\operatorname{div}} = \mathbf{b}^\top \mathbf{J}_{f_\theta}\mathbf{b}$ requires one vector-Jacobian product, which in PyTorch/JAX costs the same as one backward pass. This doubles the per-step training cost compared to supervised training.

Practical tips:

  • Use a single probe vector $\mathbf{b}$ per sample (the variance is acceptable for SGD).
  • For large images ($N > 10^6$), compute the divergence on random patches rather than the full image.
  • If using GSURE for inverse problems, the cost is $O(C_{\text{fwd}} + C_{\text{bwd}} + C_{\mathbf{A}})$ per sample.

Quick Check

For a soft-thresholding denoiser $f_\lambda(\mathbf{y})_i = \text{sign}(y_i)\max(|y_i| - \lambda, 0)$, what is $\operatorname{div}(f_\lambda)$?

$N$ (the dimension)

$|\{i : |y_i| > \lambda\}|$ (number of non-zero components)

$\lambda \cdot N$

$0$ (soft thresholding is not differentiable)

Stein's Unbiased Risk Estimate (SURE)

A formula that provides an unbiased estimate of the MSE of a denoiser without access to the clean signal, using the denoiser's divergence as a correction term. Requires Gaussian noise with known variance.

Related: Generalised SURE (GSURE)

Generalised SURE (GSURE)

An extension of SURE to inverse problems that estimates the projected MSE $\|\mathbf{A}\hat{\mathbf{x}} - \mathbf{A}\mathbf{x}\|^2$ without clean ground truth, applicable to underdetermined systems.

Related: Stein's Unbiased Risk Estimate (SURE)

Key Takeaway

SURE estimates MSE without clean targets by adding a divergence correction to the residual. The divergence is computed efficiently via Monte Carlo estimation (one extra backward pass). SURE-trained denoisers match the quality of supervised training for Gaussian noise. GSURE extends this to inverse problems but is blind to the null space, so an additional regulariser is needed there. The main limitation is the requirement of Gaussian noise with known $\sigma^2$.