Implementing Differentiable Forward Models

Definition:

Differentiable Forward Model

A differentiable forward model implements the physics operator $\mathcal{A}: \mathbb{C}^n \to \mathbb{C}^m$ in PyTorch so that gradients flow through it during training:

$$\mathbf{y} = \mathcal{A}(\mathbf{x}) + \mathbf{n}$$

Examples: FFT/IFFT, matrix multiplication $\mathbf{H}\mathbf{x}$, convolution with a known PSF, or a full OFDM transmit chain.

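import torch
import torch.nn as nn

class OFDMForward(nn.Module):
    def __init__(self, n_sub):
        super().__init__()
        self.n_sub = n_sub  # number of subcarriers (FFT size)

    def forward(self, x_freq, h):
        # x_freq: transmitted symbols in frequency domain, shape (..., n_sub)
        # h: per-subcarrier channel frequency response, broadcastable to x_freq
        y_freq = h * x_freq  # channel acts as elementwise multiplication per subcarrier
        y_time = torch.fft.ifft(y_freq, dim=-1)  # noiseless time-domain signal
        return y_time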

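PyTorch's FFT functions (torch.fft.fft, torch.fft.ifft) support autograd, including for complex-valued tensors. Linear operators such as matrix multiplication and convolution are differentiable by construction: the backward pass applies the adjoint operator to the incoming gradient.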
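To see gradients pass through an FFT-based forward model, the following minimal sketch repeats the two operations from OFDMForward.forward on random complex tensors and backpropagates a real-valued loss. The subcarrier count, the random inputs, and the energy loss are illustrative assumptions, not part of the original model.

```python
import torch

torch.manual_seed(0)
n_sub = 8  # hypothetical number of subcarriers

# Complex leaf tensors that track gradients
x_freq = torch.randn(n_sub, dtype=torch.complex64, requires_grad=True)
h = torch.randn(n_sub, dtype=torch.complex64, requires_grad=True)

# Same operations as OFDMForward.forward: per-subcarrier channel, then IFFT
y_time = torch.fft.ifft(h * x_freq, dim=-1)

# Autograd needs a real scalar loss; here, the signal energy (illustrative choice)
loss = y_time.abs().pow(2).sum()
loss.backward()

# Both inputs now carry complex-valued gradients of shape (n_sub,)
print(x_freq.grad.shape, h.grad.shape)
```

Because the loss is real, backward() works even though every intermediate tensor is complex. In a training setup, x_freq would typically come from a learnable transmitter and h from a channel model or estimator.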

Example: Differentiable MIMO Channel

Implement a differentiable MIMO channel $\mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}$ for end-to-end training.
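One possible sketch of such a channel is shown below. The class name MIMOChannel, the noise_std parameter, the tensor shapes, and the choice of circularly symmetric complex Gaussian noise are all assumptions made for illustration.

```python
import torch
import torch.nn as nn


class MIMOChannel(nn.Module):
    """Hypothetical helper computing y = Hx + n with complex Gaussian noise."""

    def __init__(self, noise_std: float = 0.1):
        super().__init__()
        # Standard deviation of each complex noise sample (assumed convention)
        self.noise_std = noise_std

    def forward(self, x, H):
        # x: (batch, n_tx) complex, H: (batch, n_rx, n_tx) complex
        y = torch.matmul(H, x.unsqueeze(-1)).squeeze(-1)  # (batch, n_rx)
        # torch.randn with a complex dtype draws unit-variance complex normal samples
        n = self.noise_std * torch.randn(y.shape, dtype=y.dtype, device=y.device)
        return y + n


torch.manual_seed(0)
batch, n_tx, n_rx = 4, 2, 3  # illustrative sizes
H = torch.randn(batch, n_rx, n_tx, dtype=torch.complex64)
x = torch.randn(batch, n_tx, dtype=torch.complex64, requires_grad=True)

channel = MIMOChannel(noise_std=0.1)
y = channel(x, H)

# Backpropagate a real-valued loss to confirm gradients reach x
y.abs().pow(2).sum().backward()
print(y.shape, x.grad.shape)
```

The noise is sampled independently of x, so it adds no term to the gradient. The gradient with respect to x flows only through the matrix product.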

Example: End-to-End Autoencoder with Forward Model

Train a transmitter-receiver pair end-to-end through a differentiable channel model.
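A minimal sketch of this idea follows, assuming a real-valued additive white Gaussian noise channel in place of a more detailed forward model. The message count, network sizes, noise level, and optimizer settings are hypothetical choices for illustration only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
M, n_ch, noise_std = 4, 2, 0.1  # assumed: 4 messages, 2 real channel uses

# Transmitter maps a one-hot message to channel symbols; receiver maps back to logits
encoder = nn.Sequential(nn.Linear(M, 16), nn.ReLU(), nn.Linear(16, n_ch))
decoder = nn.Sequential(nn.Linear(n_ch, 16), nn.ReLU(), nn.Linear(16, M))
params = list(encoder.parameters()) + list(decoder.parameters())
opt = torch.optim.Adam(params, lr=1e-2)


def channel(x):
    # Differentiable forward model: additive white Gaussian noise (assumed)
    return x + noise_std * torch.randn_like(x)


losses = []
for step in range(200):
    msgs = torch.randint(0, M, (256,))
    x = encoder(F.one_hot(msgs, M).float())
    # Normalize each codeword to a fixed energy of n_ch
    x = x / x.norm(dim=-1, keepdim=True).clamp_min(1e-8) * n_ch ** 0.5
    logits = decoder(channel(x))
    loss = F.cross_entropy(logits, msgs)
    opt.zero_grad()
    loss.backward()  # gradients reach the encoder through the channel
    opt.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

Swapping channel for another differentiable module, such as an FFT-based or matrix-based forward model, leaves the training loop unchanged, provided the receiver input size matches the channel output.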