Complex-Valued Networks
Why Complex Networks Matter
Wireless signals are inherently complex-valued: baseband I/Q data, channel coefficients, and frequency-domain representations are all complex. While you can always split into real/imaginary channels, complex-valued networks preserve phase relationships and reduce parameter count by exploiting the algebraic structure of $\mathbb{C}$.
Definition: Complex-Valued Linear Layer
Complex-Valued Linear Layer
A complex linear layer computes $z_{\text{out}} = W z_{\text{in}} + b$, where $W \in \mathbb{C}^{m \times n}$ and $b \in \mathbb{C}^{m}$.
In PyTorch (native complex support):
```python
W = nn.Parameter(torch.randn(m, n, dtype=torch.cfloat))
b = nn.Parameter(torch.zeros(m, dtype=torch.cfloat))
z_out = z_in @ W.T + b   # complex matmul plus complex bias
```
Equivalently, using the real/imaginary split $z = x + iy$, $W = W_r + iW_i$:

$$\begin{bmatrix} x_{\text{out}} \\ y_{\text{out}} \end{bmatrix} = \begin{bmatrix} W_r & -W_i \\ W_i & W_r \end{bmatrix} \begin{bmatrix} x_{\text{in}} \\ y_{\text{in}} \end{bmatrix} + \begin{bmatrix} b_r \\ b_i \end{bmatrix}$$

The structured block weight matrix (equal diagonal blocks, negated off-diagonal blocks) preserves the algebraic properties of complex multiplication. An unconstrained real 2-channel network uses twice the parameters without this structure.
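A small numerical sketch of this equivalence (sizes are illustrative):

```python
import torch

m, n = 3, 4
W = torch.randn(m, n, dtype=torch.cfloat)
z = torch.randn(n, dtype=torch.cfloat)

# Build the structured real matrix [[Wr, -Wi], [Wi, Wr]]
Wr, Wi = W.real, W.imag
W_block = torch.cat([torch.cat([Wr, -Wi], dim=1),
                     torch.cat([Wi,  Wr], dim=1)], dim=0)

xy = torch.cat([z.real, z.imag])   # stacked real/imag input
out_block = W_block @ xy           # real 2-channel computation
out_complex = W @ z                # native complex computation
print(torch.allclose(out_block[:m], out_complex.real, atol=1e-5))  # True
print(torch.allclose(out_block[m:], out_complex.imag, atol=1e-5))  # True
```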
Definition: Wirtinger Derivatives
Wirtinger Derivatives
For a function $f : \mathbb{C} \to \mathbb{C}$ with $z = x + iy$, the Wirtinger derivatives are:

$$\frac{\partial f}{\partial z} = \frac{1}{2}\left(\frac{\partial f}{\partial x} - i\frac{\partial f}{\partial y}\right), \qquad \frac{\partial f}{\partial z^*} = \frac{1}{2}\left(\frac{\partial f}{\partial x} + i\frac{\partial f}{\partial y}\right)$$
PyTorch's autograd computes the conjugate Wirtinger derivative $\partial L/\partial z^*$ for real-valued loss functions (up to the conventional factor: the stored gradient is $\partial L/\partial x + i\,\partial L/\partial y = 2\,\partial L/\partial z^*$), which is the correct gradient direction for optimisation.
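A quick numerical check of this convention: for $L = \sum |z|^2$ we have $\partial L/\partial z^* = z$, so the stored gradient should be $2z$.

```python
import torch

z = torch.tensor([1.0 + 2.0j, 3.0 - 1.0j], requires_grad=True)
loss = (z.abs() ** 2).sum()    # real-valued loss L = sum |z|^2
loss.backward()
print(torch.allclose(z.grad, 2 * z.detach()))  # True: grad = dL/dx + i dL/dy = 2 dL/dz*
```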
Definition: Complex-Valued Convolution
Complex-Valued Convolution
For complex input $z = z_r + i z_i$ and kernel $w = w_r + i w_i$:

$$w * z = (w_r * z_r - w_i * z_i) + i\,(w_r * z_i + w_i * z_r)$$

where $*$ denotes real-valued convolution.
```python
import torch
import torch.nn as nn

class ComplexConv2d(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, **kwargs):
        super().__init__()
        # Two real convolutions hold the real and imaginary kernel parts
        self.conv_re = nn.Conv2d(in_ch, out_ch, kernel_size, **kwargs)
        self.conv_im = nn.Conv2d(in_ch, out_ch, kernel_size, **kwargs)

    def forward(self, z):
        # (w_r + i w_i) * (z_r + i z_i), expanded into four real convolutions
        re = self.conv_re(z.real) - self.conv_im(z.imag)
        im = self.conv_re(z.imag) + self.conv_im(z.real)
        return torch.complex(re, im)
```
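A quick usage sketch (shapes are illustrative). Note that the two real biases combine into a single learnable complex bias:

```python
conv = ComplexConv2d(in_ch=4, out_ch=8, kernel_size=3, padding=1)
z = torch.randn(2, 4, 16, 16, dtype=torch.cfloat)
out = conv(z)
print(out.dtype, out.shape)   # torch.complex64, (2, 8, 16, 16)
```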
Definition: Complex Batch Normalisation
Complex Batch Normalisation
Standard BatchNorm applied independently to the real and imaginary parts ignores their correlation. A proper complex BatchNorm whitens with the 2D covariance:

$$\tilde{z} = V^{-1/2}\left(z - \mathbb{E}[z]\right), \qquad V = \begin{bmatrix} \operatorname{Cov}(z_r, z_r) & \operatorname{Cov}(z_r, z_i) \\ \operatorname{Cov}(z_i, z_r) & \operatorname{Cov}(z_i, z_i) \end{bmatrix}$$

where $V$ is the $2 \times 2$ covariance of the real and imaginary parts.
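A minimal training-mode sketch of the whitening step (running statistics and the learnable affine part are omitted for brevity; the closed-form inverse square root of a symmetric $2 \times 2$ matrix keeps it cheap):

```python
import torch
import torch.nn as nn

class ComplexBatchNorm2d(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.eps = eps   # num_features kept for API parity; unused in this sketch

    def forward(self, z):
        # Centre real and imaginary parts per channel
        v = torch.stack([z.real, z.imag])            # (2, B, C, H, W)
        v = v - v.mean(dim=(1, 3, 4), keepdim=True)
        # Entries of the per-channel 2x2 covariance V
        Vrr = (v[0] * v[0]).mean(dim=(0, 2, 3)) + self.eps
        Vii = (v[1] * v[1]).mean(dim=(0, 2, 3)) + self.eps
        Vri = (v[0] * v[1]).mean(dim=(0, 2, 3))
        # Closed-form V^{-1/2} for a symmetric 2x2 matrix
        s = torch.sqrt(Vrr * Vii - Vri * Vri)        # sqrt(det V)
        t = torch.sqrt(Vrr + Vii + 2 * s)            # sqrt(trace V + 2 sqrt(det V))
        inv = 1.0 / (s * t)
        Wrr, Wii, Wri = (Vii + s) * inv, (Vrr + s) * inv, -Vri * inv
        # Apply the whitening matrix W = V^{-1/2} per channel
        re = Wrr.view(1, -1, 1, 1) * v[0] + Wri.view(1, -1, 1, 1) * v[1]
        im = Wri.view(1, -1, 1, 1) * v[0] + Wii.view(1, -1, 1, 1) * v[1]
        return torch.complex(re, im)
```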
Definition: Real/Imaginary Channel Stacking
Real/Imaginary Channel Stacking
The simplest approach to handle complex data: stack real and imaginary parts as separate channels:
```python
B, C, H, W = 8, 4, 16, 16
z = torch.randn(B, C, H, W, dtype=torch.cfloat)
x = torch.cat([z.real, z.imag], dim=1)   # (B, 2C, H, W) real tensor
# Process with any standard real-valued Conv2d model that outputs 2C channels
y = model(x)
# Reconstruct complex output
z_out = torch.complex(y[:, :C], y[:, C:])
```
This approach works well in practice and is compatible with all standard layers, but does not enforce complex structure.
Theorem: Gradient Descent in Complex Domain
For a real-valued loss $L(z)$ with $z \in \mathbb{C}^n$, the steepest descent direction is:

$$\Delta z = -\frac{\partial L}{\partial z^*}$$

This is the negative of the conjugate Wirtinger derivative, which PyTorch computes automatically via loss.backward() on complex parameters.
The conjugate gradient is the correct analogue of the real gradient for optimising real-valued functions of complex variables.
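A minimal sketch of a manual descent step on a complex parameter (the quadratic loss is an illustrative choice):

```python
import torch

w = torch.randn(8, dtype=torch.cfloat, requires_grad=True)
target = torch.randn(8, dtype=torch.cfloat)

for _ in range(100):
    loss = (w - target).abs().pow(2).sum()   # real-valued loss
    loss.backward()
    with torch.no_grad():
        w -= 0.1 * w.grad                    # step along -dL/dz*
        w.grad = None

print(loss.item())   # near zero: w has converged to target
```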
Theorem: Parameter Reduction via Complex Structure
A complex linear layer $W \in \mathbb{C}^{m \times n}$ has $2mn$ real parameters (the real and imaginary parts of $W$). An unconstrained real layer on stacked real/imaginary channels ($\mathbb{R}^{2n} \to \mathbb{R}^{2m}$) has $4mn$ parameters. The complex structure halves the parameter count.
Complex multiplication constrains the $2m \times 2n$ real weight matrix to the block structure $\begin{bmatrix} W_r & -W_i \\ W_i & W_r \end{bmatrix}$, removing half the degrees of freedom.
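A quick count, ignoring biases (sizes are illustrative):

```python
import torch.nn as nn

n, m = 64, 32
complex_params = 2 * m * n                        # Re(W) and Im(W)
real_split = nn.Linear(2 * n, 2 * m, bias=False)  # unconstrained real layer
print(complex_params)                                    # 4096
print(sum(p.numel() for p in real_split.parameters()))  # 8192
```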
Theorem: Phase Equivariance of Complex Networks
A bias-free complex linear layer $f(z) = Wz$ satisfies phase equivariance:

$$f(e^{j\phi} z) = e^{j\phi} f(z)$$

for any global phase $\phi$. This symmetry is natural for wireless signals, where the absolute phase is arbitrary.
If you rotate all inputs by a common phase, the outputs rotate by the same amount. This is exactly the right behaviour for coherent signal processing.
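A numerical check with a random bias-free layer (note that modReLU preserves this property through nonlinearities, while CReLU does not):

```python
import torch

W = torch.randn(8, 16, dtype=torch.cfloat)
z = torch.randn(16, dtype=torch.cfloat)
phase = torch.exp(1j * torch.rand(()) * 2 * torch.pi)

lhs = W @ (phase * z)   # rotate input, then apply layer
rhs = phase * (W @ z)   # apply layer, then rotate output
print(torch.allclose(lhs, rhs, atol=1e-5))   # True
```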
Example: Complex-Valued MLP for Channel Estimation
Build a complex-valued MLP that maps received pilot signals $y \in \mathbb{C}^{n_{\text{pilot}}}$ to channel estimates $\hat{h} \in \mathbb{C}^{n_{\text{channel}}}$.
Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ComplexMLP(nn.Module):
    def __init__(self, n_pilot, n_channel, hidden=128):
        super().__init__()
        # Each complex layer is parameterised by two real Linears
        self.fc1_re = nn.Linear(n_pilot, hidden)
        self.fc1_im = nn.Linear(n_pilot, hidden)
        self.fc2_re = nn.Linear(hidden, n_channel)
        self.fc2_im = nn.Linear(hidden, n_channel)

    def complex_linear(self, z, fc_re, fc_im):
        # (W_r + i W_i)(z_r + i z_i) = (W_r z_r - W_i z_i) + i (W_r z_i + W_i z_r)
        re = fc_re(z.real) - fc_im(z.imag)
        im = fc_re(z.imag) + fc_im(z.real)
        return torch.complex(re, im)

    def forward(self, y_pilot):
        h = self.complex_linear(y_pilot, self.fc1_re, self.fc1_im)
        h = torch.complex(F.relu(h.real), F.relu(h.imag))   # CReLU activation
        return self.complex_linear(h, self.fc2_re, self.fc2_im)
```
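A usage sketch on random data (sizes are illustrative):

```python
model = ComplexMLP(n_pilot=64, n_channel=32)
y_pilot = torch.randn(256, 64, dtype=torch.cfloat)   # batch of pilot observations
h_hat = model(y_pilot)                               # complex estimates, (256, 32)
print(h_hat.dtype, h_hat.shape)
```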
Example: Real/Imag Split vs Complex Network Comparison
Compare parameter count and MSE for real-split vs complex-structured networks on a channel estimation task.
Key result
With the same hidden width, the complex network has half the parameters of the real-split network and often achieves lower MSE, because the complex structure acts as a regulariser.
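A minimal sketch of the complex-network half of such a comparison on synthetic data (the linear pilot model, noise level, and sizes are illustrative assumptions); the real-split baseline is trained the same way on stacked real/imaginary channels:

```python
import torch

B, n_pilot, n_channel = 1024, 64, 32
# Hypothetical linear pilot model: y = h A + noise
A = torch.randn(n_channel, n_pilot, dtype=torch.cfloat) / n_channel ** 0.5
h_true = torch.randn(B, n_channel, dtype=torch.cfloat)
y = h_true @ A + 0.05 * torch.randn(B, n_pilot, dtype=torch.cfloat)

model = ComplexMLP(n_pilot, n_channel)   # from the example above
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
for _ in range(500):
    opt.zero_grad()
    loss = (model(y) - h_true).abs().pow(2).mean()   # real-valued complex MSE
    loss.backward()
    opt.step()
print(f"final training MSE: {loss.item():.4f}")
```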
Example: PyTorch Native Complex Support
Demonstrate PyTorch's native complex tensor support and autograd.
Usage
```python
import torch

z = torch.randn(10, dtype=torch.cfloat, requires_grad=True)
w = torch.randn(10, dtype=torch.cfloat, requires_grad=True)
loss = (z * w).abs().sum()   # real-valued loss of complex inputs
loss.backward()
print(z.grad)                # conjugate Wirtinger gradient dL/dz*
```
Complex Network Transformation Viewer
See how a complex linear layer transforms points in the complex plane.
Parameter Count: Complex vs Real-Split
Compare parameter counts for complex-structured vs unconstrained real networks.
Wirtinger Gradient Visualisation
Visualise the gradient direction for complex-valued optimisation.
Phase Equivariance Demonstration
Rotate input phase and see output rotate identically.
Complex vs Real Linear Layer
Complex-Valued NN Processing Pipeline
Quick Check
What gradient does PyTorch compute for complex parameters when calling .backward()?
The standard partial derivative df/dz
The Wirtinger conjugate derivative df/dz*
Gradients with respect to real and imaginary parts separately
For real-valued loss functions, the steepest descent step is along -df/dz*.
Quick Check
How does a complex linear layer's parameter count compare to an equivalent real layer?
Same number of parameters
Half the parameters
Double the parameters
Complex multiplication constrains the weight matrix, halving parameters.
Quick Check
What is the simplest way to process complex data with standard PyTorch layers?
Use only the magnitude (discard phase)
Stack real and imaginary parts as separate channels
Convert to polar coordinates
This gives (B, 2C, H, W) input compatible with standard Conv2d.
Common Mistake: Applying ReLU to Complex Tensors
Mistake:
Calling F.relu(z) on a complex tensor. ReLU is not implemented for complex dtypes, so this raises an error rather than applying a sensible nonlinearity.
Correction:
Use the split activation (CReLU): torch.complex(F.relu(z.real), F.relu(z.imag)),
or modReLU: z * relu(|z| - b) / |z|, which shrinks the magnitude while preserving the phase.
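A minimal modReLU sketch, following the convention above (the per-feature bias b is learnable; the small eps avoids division by zero):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ModReLU(nn.Module):
    def __init__(self, num_features):
        super().__init__()
        self.b = nn.Parameter(torch.zeros(num_features))

    def forward(self, z):
        mag = z.abs()
        # Shrink the magnitude, keep the phase factor z / |z|
        return F.relu(mag - self.b) * z / (mag + 1e-8)
```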
Common Mistake: Complex-Valued Loss Functions
Mistake:
Computing MSE directly on complex tensors: (z_pred - z_true).pow(2).mean()
gives a complex loss, which cannot be backpropagated.
Correction:
Use (z_pred - z_true).abs().pow(2).mean() or
F.mse_loss(z_pred.real, z_true.real) + F.mse_loss(z_pred.imag, z_true.imag).
Common Mistake: Standard BatchNorm on Complex Data
Mistake:
Applying nn.BatchNorm2d to stacked real/imag channels independently,
ignoring the correlation between real and imaginary parts.
Correction:
For most practical purposes, independent normalisation works fine. For strict correctness, implement 2x2 covariance whitening.
Key Takeaway
For wireless applications, start with real/imaginary channel stacking (simplest) and switch to complex-structured layers only if you need phase equivariance or parameter reduction.
Key Takeaway
PyTorch handles complex autograd correctly via Wirtinger calculus. You only need to ensure your loss function is real-valued.
Why This Matters: Complex Networks for MIMO Detection
MIMO detection involves recovering complex transmitted symbols $x$ from received signals $y = Hx + n$. Complex-valued networks naturally handle the I/Q structure and can learn near-optimal detectors with far fewer parameters than real-split approaches.
Historical Note: Complex-Valued Neural Networks
1992–2020: Complex-valued neural networks were studied theoretically in the 1990s (Hirose, 1992; Georgiou & Koutsougeras, 1992) but gained practical traction only after deep learning frameworks added complex tensor support. PyTorch added native complex autograd in v1.7 (2020).
Historical Note: Wirtinger Calculus
1927–1983: Wilhelm Wirtinger introduced this calculus in 1927. Its application to the optimisation of real-valued functions of complex variables was formalised by Brandwood (1983) and adopted by the signal processing community for adaptive filtering.
Wirtinger Derivative
Partial derivative with respect to $z$ or $z^*$, treating the other as formally independent, enabling gradient-based optimisation in the complex domain.
Phase Equivariance
Property where a global phase rotation of the input produces the same rotation in the output.
CReLU
Complex ReLU that applies ReLU independently to real and imaginary parts.
modReLU
Complex activation $\text{modReLU}(z) = \text{ReLU}(|z| - b)\,\frac{z}{|z|}$, preserving phase.
Differentiable Forward Model
A physics operator implemented in PyTorch so gradients can flow through it during training.
Complex Data Handling Approaches
| Approach | Implementation | Parameters | Phase Aware | Ease of Use |
|---|---|---|---|---|
| Real/Imag stack | torch.cat([z.real, z.imag], dim=1) | 4mn | No | Easiest |
| Complex structured | Constrained W with complex multiplication | 2mn | Yes | Medium |
| Native complex | torch.cfloat tensors + autograd | 2mn | Yes | Easy (PyTorch 1.7+) |