Implementing Differentiable Forward Models
Definition: Differentiable Forward Model
A differentiable forward model implements the physics operator in PyTorch so that gradients flow through it during training:
Examples: the FFT/IFFT, matrix multiplication, convolution with a known point-spread function (PSF), or a full OFDM transmit chain.
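For the PSF case in the list above, a forward model can be as small as a fixed convolution. The sketch below is a minimal illustration built on assumptions that are not in the original:

- the class name `PSFBlur` and the helper `gaussian_psf` are hypothetical;
- the input is a single-channel image;
- the PSF is a symmetric Gaussian. PyTorch's `F.conv2d` computes cross-correlation, so an asymmetric PSF would have to be flipped first.

The PSF is registered as a buffer, so it stays fixed while gradients still reach the input image.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PSFBlur(nn.Module):
    """Hypothetical forward model y = psf * x for a known, fixed PSF."""

    def __init__(self, psf: torch.Tensor):
        super().__init__()
        # Normalize so the blur preserves total intensity. A buffer follows
        # .to(device) but is not returned by .parameters(), so it is not trained.
        psf = psf / psf.sum()
        self.register_buffer("psf", psf[None, None])  # shape (1, 1, kh, kw)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, 1, H, W); "same" padding keeps the output size equal to the input.
        # conv2d is cross-correlation, which equals convolution for a symmetric PSF.
        return F.conv2d(x, self.psf, padding="same")


def gaussian_psf(size: int = 5, sigma: float = 1.0) -> torch.Tensor:
    # Separable 2-D Gaussian built as an outer product of 1-D profiles
    ax = torch.arange(size, dtype=torch.float32) - (size - 1) / 2
    g = torch.exp(-ax**2 / (2 * sigma**2))
    return torch.outer(g, g)


blur = PSFBlur(gaussian_psf())
x = torch.rand(2, 1, 16, 16, requires_grad=True)
y = blur(x)
y.pow(2).sum().backward()  # gradients flow through the fixed blur back to x
print(y.shape, x.grad.shape)
```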
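```python
import torch
import torch.nn as nn


class OFDMForward(nn.Module):
    def __init__(self, n_sub):
        super().__init__()
        self.n_sub = n_sub  # number of subcarriers

    def forward(self, x_freq, h):
        # x_freq: transmitted symbols in the frequency domain, shape (..., n_sub)
        # h: channel frequency response per subcarrier, broadcastable to x_freq
        assert x_freq.shape[-1] == self.n_sub
        y_freq = h * x_freq  # channel in frequency domain (valid with an adequate cyclic prefix)
        y_time = torch.fft.ifft(y_freq, dim=-1)  # back to the time domain
        return y_time
```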
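Here is why a per-subcarrier product models a multipath channel. With a cyclic prefix at least as long as the channel memory, linear convolution with the channel taps becomes circular convolution, and the DFT turns circular convolution into multiplication.

The check below is a sketch under these assumptions:

- `h` in `OFDMForward` is the DFT of the channel taps on the `n_sub`-point grid;
- the sizes (`n_sub=64`, 4 taps, an 8-sample prefix) and the QPSK mapping are illustrative choices.

It compares `ifft(h_freq * x_freq)` against an explicit chain: IFFT, add the prefix, convolve, strip the prefix.

```python
import torch

torch.manual_seed(0)
n_sub, n_taps, cp_len = 64, 4, 8  # assumed sizes; requires cp_len >= n_taps - 1

# Unit-power QPSK symbols on each subcarrier
bits = torch.randint(0, 2, (2, n_sub)).float()
x_freq = ((2 * bits[0] - 1) + 1j * (2 * bits[1] - 1)) / 2**0.5

# Multipath channel taps (time domain) and their frequency response on the subcarrier grid
h_taps = torch.randn(n_taps, dtype=torch.complex64) / n_taps**0.5
h_freq = torch.fft.fft(h_taps, n=n_sub)

# Frequency-domain model, as computed by OFDMForward
y_model = torch.fft.ifft(h_freq * x_freq)

# Explicit time-domain chain: IFFT, prepend cyclic prefix, linear convolution, drop prefix
s = torch.fft.ifft(x_freq)
s_cp = torch.cat([s[-cp_len:], s])
y_lin = torch.zeros(s_cp.numel() + n_taps - 1, dtype=s_cp.dtype)
for k in range(n_taps):
    y_lin[k:k + s_cp.numel()] += h_taps[k] * s_cp
y_time = y_lin[cp_len:cp_len + n_sub]

print(torch.allclose(y_model, y_time, atol=1e-5))
```

If the prefix were shorter than `n_taps - 1`, the two outputs would no longer match, and the frequency-domain model would stop being exact.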
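PyTorch's FFT functions (torch.fft.fft, torch.fft.ifft) are fully differentiable, including for complex inputs. Linear operators need no hand-written gradient: autograd computes the backward pass of y = Ax by applying the adjoint of A to the incoming gradient. Any operator built from such primitives is therefore differentiable automatically.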
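Both claims can be checked numerically. The sketch below has two parts (the wrapper name `ifft_real_io` is hypothetical):

1. It uses `torch.autograd.gradcheck`, which compares autograd gradients against finite differences in double precision. The complex IFFT is wrapped with real inputs and outputs so that every real degree of freedom is tested.
2. It confirms that, for a real matrix A, the backward pass of y = Ax applied to an upstream gradient g is exactly A^T g.

```python
import torch
from torch.autograd import gradcheck

torch.manual_seed(0)


def ifft_real_io(xr, xi):
    # Real parts in, (real, imag) pairs out: gradcheck then perturbs and compares
    # every real degree of freedom of the complex IFFT.
    return torch.view_as_real(torch.fft.ifft(torch.complex(xr, xi), dim=-1))


xr = torch.randn(3, 8, dtype=torch.float64, requires_grad=True)
xi = torch.randn(3, 8, dtype=torch.float64, requires_grad=True)
ok = gradcheck(ifft_real_io, (xr, xi))  # raises on mismatch, returns True otherwise
print(ok)

# Vector-Jacobian product of a linear operator equals its adjoint (transpose here)
A = torch.randn(5, 8, dtype=torch.float64)
x = torch.randn(8, dtype=torch.float64, requires_grad=True)
g = torch.randn(5, dtype=torch.float64)
(A @ x).backward(g)
print(torch.allclose(x.grad, A.T @ g))
```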
Example: Differentiable MIMO Channel
Implement a differentiable MIMO channel for end-to-end training.
Solution
Implementation
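```python
import torch
import torch.nn as nn


class MIMOChannel(nn.Module):
    """y = H x + n with additive white Gaussian noise."""

    def __init__(self, n_tx, n_rx, noise_std=0.1):
        super().__init__()
        self.noise_std = noise_std
        self.n_tx, self.n_rx = n_tx, n_rx

    def forward(self, x, H):
        # x: (..., n_tx) transmitted symbols; H: (..., n_rx, n_tx) channel matrix
        y = torch.matmul(H, x.unsqueeze(-1)).squeeze(-1)  # (..., n_rx)
        # For complex y, randn_like draws CN(0, 1), so noise_std is the total complex std
        noise = self.noise_std * torch.randn_like(y)
        return y + noise
```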
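One detail of `MIMOChannel` affects how `noise_std` maps to SNR. For a complex `y`, `torch.randn_like` draws CN(0, 1) samples whose real and imaginary parts each have variance 1/2, so `noise_std` is the total complex noise standard deviation.

The sketch below makes three assumptions that are not in the original:

- the transmitted symbols are unit-power QPSK;
- the channel is Rayleigh fading with CN(0, 1/n_tx) entries;
- SNR is defined per receive antenna.

The helper `snr_db_to_noise_std` is hypothetical. The code converts a target SNR to `noise_std`, passes symbols through y = Hx + n as in `forward`, and measures the resulting SNR.

```python
import math
import torch

torch.manual_seed(0)


def snr_db_to_noise_std(snr_db, signal_power=1.0):
    # Hypothetical helper. SNR = (average received signal power per antenna)
    # / (total complex noise power E|n|^2), so noise_std^2 = P / 10^(SNR/10).
    return math.sqrt(signal_power / 10 ** (snr_db / 10))


n_tx, n_rx, batch = 4, 4, 50_000

# Rayleigh fading with i.i.d. CN(0, 1/n_tx) entries: unit-power transmit symbols
# then give unit average signal power at each receive antenna.
H = torch.randn(batch, n_rx, n_tx, dtype=torch.complex64) / math.sqrt(n_tx)

# Unit-power QPSK symbols
b = torch.randint(0, 2, (2, batch, n_tx)).float()
x = ((2 * b[0] - 1) + 1j * (2 * b[1] - 1)) / math.sqrt(2)

clean = torch.matmul(H, x.unsqueeze(-1)).squeeze(-1)  # (batch, n_rx)
noise_std = snr_db_to_noise_std(10.0)
noise = noise_std * torch.randn_like(clean)  # each of real/imag has variance noise_std^2 / 2
y = clean + noise

signal_power = clean.abs().pow(2).mean().item()
noise_power = noise.abs().pow(2).mean().item()
measured_snr_db = 10 * math.log10(signal_power / noise_power)
print(f"noise_std={noise_std:.4f}, measured SNR={measured_snr_db:.2f} dB")
```

Under these assumptions, the default `noise_std=0.1` corresponds to an SNR of about 20 dB.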
Example: End-to-End Autoencoder with Forward Model
Train a transmitter-receiver pair end-to-end through a differentiable channel model.
Solution
Architecture
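```python
import torch
import torch.nn as nn


class CommAutoencoder(nn.Module):
    def __init__(self, n_bits, n_symbols):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(n_bits, 64), nn.ReLU(),
            nn.Linear(64, 2 * n_symbols))  # interleaved I/Q outputs
        self.channel = MIMOChannel(n_symbols, n_symbols)
        self.decoder = nn.Sequential(
            nn.Linear(2 * n_symbols, 64), nn.ReLU(),
            nn.Linear(64, n_bits))

    def forward(self, bits, H):
        iq = self.encoder(bits.float())
        x = torch.complex(iq[..., ::2], iq[..., 1::2])  # (..., n_symbols) complex symbols
        # Power constraint: without it the encoder can beat the noise by scaling its output
        x = x / x.abs().pow(2).mean(dim=-1, keepdim=True).sqrt()
        y = self.channel(x, H)  # H must be complex, shape (..., n_symbols, n_symbols)
        y_real = torch.cat([y.real, y.imag], dim=-1)
        return torch.sigmoid(self.decoder(y_real))  # per-bit probabilities
```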
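A training loop for this architecture might look like the sketch below. To stay self-contained, it re-declares a condensed model with the same encoder, channel, and decoder structure. The following are assumptions, not part of the original:

- the model name `TinyAE` is hypothetical;
- the transmitted symbols obey a per-codeword unit average power constraint;
- the decoder outputs logits paired with `nn.BCEWithLogitsLoss`, in place of a final sigmoid, for numerical stability;
- the channel matrix is fixed and known;
- the sizes, learning rate, and step count are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_bits, n_symbols, noise_std = 4, 2, 0.1


class TinyAE(nn.Module):
    # Hypothetical condensed autoencoder: encoder -> power norm -> y = Hx + n -> decoder
    def __init__(self):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(n_bits, 64), nn.ReLU(), nn.Linear(64, 2 * n_symbols))
        self.dec = nn.Sequential(nn.Linear(2 * n_symbols, 64), nn.ReLU(), nn.Linear(64, n_bits))

    def forward(self, bits, H):
        iq = self.enc(bits)
        x = torch.complex(iq[..., ::2], iq[..., 1::2])
        x = x / x.abs().pow(2).mean(dim=-1, keepdim=True).sqrt()  # unit average power
        y = torch.matmul(H, x.unsqueeze(-1)).squeeze(-1)  # H broadcasts over the batch
        y = y + noise_std * torch.randn_like(y)
        return self.dec(torch.cat([y.real, y.imag], dim=-1))  # logits, not probabilities


model = TinyAE()
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.BCEWithLogitsLoss()
# Fixed complex channel, assumed known to both ends
H = torch.eye(n_symbols, dtype=torch.complex64) + 0.3 * torch.randn(
    n_symbols, n_symbols, dtype=torch.complex64)

losses = []
for step in range(500):
    bits = torch.randint(0, 2, (256, n_bits)).float()
    loss = loss_fn(model(bits, H), bits)  # gradients pass through the channel model
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())

with torch.no_grad():
    bits = torch.randint(0, 2, (10_000, n_bits)).float()
    ber = ((model(bits, H) > 0).float() != bits).float().mean().item()
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}, BER {ber:.4f}")
```

Bit decisions threshold the logits at 0. This is equivalent to thresholding the sigmoid output of `CommAutoencoder` at 0.5.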