Complex-Valued Networks

Why Complex Networks Matter

Wireless signals are inherently complex-valued: baseband I/Q data, channel coefficients, and frequency-domain representations are all complex. While you can always split into real/imaginary channels, complex-valued networks preserve phase relationships and reduce parameter count by exploiting the algebraic structure of \mathbb{C}.

Definition:

Complex-Valued Linear Layer

A complex linear layer computes \mathbf{z}_{\text{out}} = \mathbf{W}\mathbf{z}_{\text{in}} + \mathbf{b} where \mathbf{W} \in \mathbb{C}^{m \times n}, \mathbf{b} \in \mathbb{C}^m.

In PyTorch (native complex support):

import torch
import torch.nn as nn

W = nn.Parameter(torch.randn(m, n, dtype=torch.cfloat))
b = nn.Parameter(torch.zeros(m, dtype=torch.cfloat))
z_out = z_in @ W.T + b  # complex matmul; W.T is a plain (non-conjugating) transpose

Equivalently, using the real/imaginary split:

\begin{bmatrix} \Re(\mathbf{z}_{\text{out}}) \\ \Im(\mathbf{z}_{\text{out}}) \end{bmatrix} = \begin{bmatrix} \Re(\mathbf{W}) & -\Im(\mathbf{W}) \\ \Im(\mathbf{W}) & \Re(\mathbf{W}) \end{bmatrix} \begin{bmatrix} \Re(\mathbf{z}_{\text{in}}) \\ \Im(\mathbf{z}_{\text{in}}) \end{bmatrix}

This structured block weight matrix (the off-diagonal blocks are negatives of each other) is the real representation of complex multiplication and preserves its algebraic properties. An unconstrained real 2-channel network uses twice the parameters without this structure.
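
A quick numerical check of this equivalence, as a minimal sketch (the sizes m, n and the random inputs are arbitrary):

import torch

m, n = 3, 4
W = torch.randn(m, n, dtype=torch.cfloat)
z = torch.randn(n, dtype=torch.cfloat)

out_complex = W @ z  # native complex matrix-vector product

# Real block matrix [[Re W, -Im W], [Im W, Re W]] acting on stacked [Re z; Im z]
W_block = torch.cat([
    torch.cat([W.real, -W.imag], dim=1),
    torch.cat([W.imag,  W.real], dim=1),
], dim=0)
out_stacked = W_block @ torch.cat([z.real, z.imag])

assert torch.allclose(out_stacked[:m], out_complex.real, atol=1e-5)
assert torch.allclose(out_stacked[m:], out_complex.imag, atol=1e-5)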

Definition:

Wirtinger Derivatives

For a function f: \mathbb{C} \to \mathbb{R} of z = x + jy, the Wirtinger derivatives are:

\frac{\partial f}{\partial z} = \frac{1}{2}\left(\frac{\partial f}{\partial x} - j\frac{\partial f}{\partial y}\right), \qquad \frac{\partial f}{\partial z^*} = \frac{1}{2}\left(\frac{\partial f}{\partial x} + j\frac{\partial f}{\partial y}\right)

PyTorch's autograd computes \partial f / \partial z^* for real-valued loss functions, which is the correct gradient direction for optimisation.
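
For example, for f(z) = |z|^2 = x^2 + y^2 the conjugate Wirtinger derivative is \frac{\partial f}{\partial z^*} = \frac{1}{2}(2x + j\,2y) = z, which points radially outward, i.e. in the direction of steepest increase of the magnitude.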

Definition:

Complex-Valued Convolution

For complex input \mathbf{Z} and kernel \mathbf{K}:

\mathbf{Z} * \mathbf{K} = (\Re(\mathbf{Z}) * \Re(\mathbf{K}) - \Im(\mathbf{Z}) * \Im(\mathbf{K})) + j(\Re(\mathbf{Z}) * \Im(\mathbf{K}) + \Im(\mathbf{Z}) * \Re(\mathbf{K}))

import torch
import torch.nn as nn

class ComplexConv2d(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, **kwargs):
        super().__init__()
        # Two real-valued convolutions hold Re(K) and Im(K)
        self.conv_re = nn.Conv2d(in_ch, out_ch, kernel_size, **kwargs)
        self.conv_im = nn.Conv2d(in_ch, out_ch, kernel_size, **kwargs)

    def forward(self, z):
        # (Re Z * Re K - Im Z * Im K) + j(Re Z * Im K + Im Z * Re K)
        re = self.conv_re(z.real) - self.conv_im(z.imag)
        im = self.conv_re(z.imag) + self.conv_im(z.real)
        return torch.complex(re, im)
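
A minimal usage sketch for the class above (the batch size, channel counts, and spatial size are arbitrary placeholders):

conv = ComplexConv2d(in_ch=2, out_ch=8, kernel_size=3, padding=1)
z = torch.randn(4, 2, 32, 32, dtype=torch.cfloat)  # (batch, channels, H, W)
out = conv(z)  # complex output of shape (4, 8, 32, 32)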

Definition:

Complex Batch Normalisation

Applying standard BatchNorm to the real and imaginary parts independently ignores their correlation. A proper complex BatchNorm whitens the joint 2D covariance:

\hat{\mathbf{z}} = \mathbf{V}^{-1/2}(\mathbf{z} - \boldsymbol{\mu})

where \mathbf{V} = \begin{bmatrix} V_{rr} & V_{ri} \\ V_{ri} & V_{ii} \end{bmatrix} is the 2×2 covariance matrix of the real and imaginary parts.
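
A minimal sketch of this whitening step for features of shape (batch, features), using the closed-form inverse square root of a 2×2 symmetric matrix (the function name, per-feature statistics, and eps value are illustrative choices; learnable affine parameters are omitted):

import torch

def complex_whiten(z, eps=1e-5):
    """Whiten complex features z of shape (batch, features) using the
    per-feature 2x2 covariance of real and imaginary parts."""
    x, y = z.real, z.imag
    xc, yc = x - x.mean(dim=0), y - y.mean(dim=0)

    # Per-feature 2x2 covariance entries (eps stabilises the diagonal)
    v_rr = (xc * xc).mean(dim=0) + eps
    v_ii = (yc * yc).mean(dim=0) + eps
    v_ri = (xc * yc).mean(dim=0)

    # Closed-form inverse square root of [[v_rr, v_ri], [v_ri, v_ii]]
    s = torch.sqrt(v_rr * v_ii - v_ri ** 2)      # sqrt of determinant
    t = torch.sqrt(v_rr + v_ii + 2 * s)
    inv_st = 1.0 / (s * t)
    w_rr = (v_ii + s) * inv_st
    w_ii = (v_rr + s) * inv_st
    w_ri = -v_ri * inv_st

    x_hat = w_rr * xc + w_ri * yc
    y_hat = w_ri * xc + w_ii * yc
    return torch.complex(x_hat, y_hat)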

Definition:

Real/Imaginary Channel Stacking

The simplest approach to handle complex data: stack real and imaginary parts as separate channels:

z = torch.randn(B, C, H, W, dtype=torch.cfloat)
x = torch.cat([z.real, z.imag], dim=1)  # (B, 2C, H, W), real-valued
# Process with a standard real-valued Conv2d model mapping 2C -> 2C channels
y = model(x)
# Reconstruct the complex output from the first/second halves of the channels
z_out = torch.complex(y[:, :C], y[:, C:])

This approach works well in practice and is compatible with all standard layers, but does not enforce complex structure.

Theorem: Gradient Descent in Complex Domain

For a real-valued loss L(\mathbf{z}) with \mathbf{z} \in \mathbb{C}^n, the steepest descent update is:

\mathbf{z}_{t+1} = \mathbf{z}_t - \eta \frac{\partial L}{\partial \mathbf{z}^*}

This is the conjugate Wirtinger derivative, which PyTorch computes automatically via loss.backward() on complex parameters.

The conjugate Wirtinger derivative \partial L / \partial \mathbf{z}^* is the correct analogue of the real gradient for optimising real-valued functions of complex variables.
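
As a minimal sketch, a manual gradient-descent loop on a complex least-squares objective using PyTorch's Wirtinger-based autograd (the problem size, learning rate, and iteration count are arbitrary):

import torch

torch.manual_seed(0)
A = torch.randn(8, 4, dtype=torch.cfloat)
b = torch.randn(8, dtype=torch.cfloat)
z = torch.zeros(4, dtype=torch.cfloat, requires_grad=True)

lr = 0.05
for step in range(200):
    loss = (A @ z - b).abs().pow(2).mean()  # real-valued loss
    loss.backward()                         # grad points along dL/dz*
    with torch.no_grad():
        z -= lr * z.grad                    # steepest descent step
        z.grad.zero_()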

Theorem: Parameter Reduction via Complex Structure

A complex linear layer \mathbb{C}^n \to \mathbb{C}^m has 2mn real parameters (the real and imaginary parts of \mathbf{W}). An unconstrained real layer \mathbb{R}^{2n} \to \mathbb{R}^{2m} has 4mn parameters. The complex structure halves the parameter count.

Complex multiplication constrains the real block weight matrix to the structured form shown above, removing half the degrees of freedom.

Theorem: Phase Equivariance of Complex Networks

A bias-free complex linear layer f(\mathbf{z}) = \mathbf{W}\mathbf{z} satisfies phase equivariance:

f(e^{j\phi}\mathbf{z}) = e^{j\phi}f(\mathbf{z})

for any global phase \phi. This symmetry is natural for wireless signals, where the absolute phase is arbitrary.

If you rotate all inputs by a common phase, the outputs rotate by the same amount. This is exactly the right behaviour for coherent signal processing.
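
A quick numerical check of this property, as a minimal sketch (the layer size and test phase are arbitrary):

import torch

W = torch.randn(5, 3, dtype=torch.cfloat)
z = torch.randn(3, dtype=torch.cfloat)
phase = torch.exp(1j * torch.tensor(0.7))  # arbitrary global phase rotation

# Rotating the input is the same as rotating the output
assert torch.allclose(W @ (phase * z), phase * (W @ z), atol=1e-5)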

Example: Complex-Valued MLP for Channel Estimation

Build a complex-valued MLP that maps received pilot signals \mathbf{y}_p \in \mathbb{C}^{N_p} to channel estimates \hat{\mathbf{h}} \in \mathbb{C}^{N}.
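
One possible sketch of such a model (the ComplexLinear class, layer widths, split-ReLU activation, and synthetic training snippet are illustrative choices, not a prescribed architecture):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ComplexLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        scale = (2.0 * in_features) ** -0.5   # rough variance-preserving init (illustrative)
        self.W = nn.Parameter(scale * torch.randn(out_features, in_features, dtype=torch.cfloat))
        self.b = nn.Parameter(torch.zeros(out_features, dtype=torch.cfloat))

    def forward(self, z):
        return z @ self.W.T + self.b

def crelu(z):
    # Split activation: ReLU on real and imaginary parts independently
    return torch.complex(F.relu(z.real), F.relu(z.imag))

class ChannelEstimatorMLP(nn.Module):
    def __init__(self, n_pilots, n_taps, hidden=128):
        super().__init__()
        self.fc1 = ComplexLinear(n_pilots, hidden)
        self.fc2 = ComplexLinear(hidden, hidden)
        self.out = ComplexLinear(hidden, n_taps)

    def forward(self, y_pilot):
        h = crelu(self.fc1(y_pilot))
        h = crelu(self.fc2(h))
        return self.out(h)

# One training step with a real-valued loss (synthetic data for illustration)
model = ChannelEstimatorMLP(n_pilots=16, n_taps=64)
y_p = torch.randn(32, 16, dtype=torch.cfloat)     # batch of received pilots
h_true = torch.randn(32, 64, dtype=torch.cfloat)  # ground-truth channel vectors
loss = (model(y_p) - h_true).abs().pow(2).mean()
loss.backward()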

Example: Real/Imag Split vs Complex Network Comparison

Compare parameter count and MSE for real-split vs complex-structured networks on a channel estimation task.
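
The parameter-count half of this comparison can be sketched directly for a single layer (the sizes are arbitrary; the MSE comparison requires the full training setup of the example):

import torch
import torch.nn as nn

n, m = 64, 128

# Complex-structured layer: one complex weight matrix
W_complex = nn.Parameter(torch.randn(m, n, dtype=torch.cfloat))
complex_params = 2 * W_complex.numel()   # real + imaginary parts: 2mn

# Unconstrained real layer on stacked real/imag features
real_split = nn.Linear(2 * n, 2 * m, bias=False)
real_params = real_split.weight.numel()  # 4mn

print(complex_params, real_params)       # 16384 vs 32768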

Example: PyTorch Native Complex Support

Demonstrate PyTorch's native complex tensor support and autograd.
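
A minimal sketch of such a demonstration (the values are arbitrary; references differ on a factor-of-two convention for the conjugate Wirtinger derivative, so the final comment only claims proportionality):

import torch

# Native complex tensors: arithmetic, matmul, and FFTs work on cfloat/cdouble
z = torch.tensor([1.0 + 2.0j, 3.0 - 1.0j], requires_grad=True)
print(z.dtype)              # torch.complex64
print(z.abs(), z.angle())   # magnitude and phase
print(torch.fft.fft(z))     # FFT of a complex tensor

# Autograd with a real-valued loss
loss = z.abs().pow(2).sum()
loss.backward()
print(z.grad)  # proportional to the conjugate Wirtinger derivative dL/dz* = z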

Complex Network Transformation Viewer

See how a complex linear layer transforms points in the complex plane.

Parameter Count: Complex vs Real-Split

Compare parameter counts for complex-structured vs unconstrained real networks.

Wirtinger Gradient Visualisation

Visualise the gradient direction for complex-valued optimisation.

Phase Equivariance Demonstration

Rotate input phase and see output rotate identically.

Complex vs Real Linear Layer

A complex linear layer has a structured block weight matrix, using half the parameters of an unconstrained real layer.

Complex-Valued NN Processing Pipeline

End-to-end processing: I/Q input, complex Conv2d layers, complex activation, and loss computation with Wirtinger gradients.

Quick Check

What gradient does PyTorch compute for complex parameters when calling .backward()?

The standard partial derivative df/dz

The Wirtinger conjugate derivative df/dz*

Gradients with respect to real and imaginary parts separately

Quick Check

How does a complex linear layer's parameter count compare to an equivalent real layer?

Same number of parameters

Half the parameters

Double the parameters

Quick Check

What is the simplest way to process complex data with standard PyTorch layers?

Use only the magnitude (discard phase)

Stack real and imaginary parts as separate channels

Convert to polar coordinates

Common Mistake: Applying ReLU to Complex Tensors

Mistake:

Using F.relu(z) on a complex tensor, which only clips the real part.

Correction:

Use split activation: torch.complex(F.relu(z.real), F.relu(z.imag)), or use modReLU: z * relu(|z| - b) / |z|, or use CReLU.
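
A sketch of a modReLU module implementing the formula above (the zero-initialised learnable bias and the eps guard against division by zero at the origin are illustrative choices); the split activation is the CReLU expression shown inline above:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ModReLU(nn.Module):
    # modReLU(z) = z * ReLU(|z| - b) / |z|: thresholds the magnitude, preserves the phase
    def __init__(self, num_features):
        super().__init__()
        self.b = nn.Parameter(torch.zeros(num_features))

    def forward(self, z, eps=1e-8):
        mag = z.abs()
        return z * F.relu(mag - self.b) / (mag + eps)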

Common Mistake: Complex-Valued Loss Functions

Mistake:

Computing MSE directly on complex tensors: (z_pred - z_true).pow(2).mean() gives a complex loss, which cannot be backpropagated.

Correction:

Use (z_pred - z_true).abs().pow(2).mean() or F.mse_loss(z_pred.real, z_true.real) + F.mse_loss(z_pred.imag, z_true.imag).

Common Mistake: Standard BatchNorm on Complex Data

Mistake:

Applying nn.BatchNorm2d to stacked real/imag channels independently, ignoring the correlation between real and imaginary parts.

Correction:

For most practical purposes, independent normalisation works fine. For strict correctness, implement 2x2 covariance whitening.

Key Takeaway

For wireless applications, start with real/imaginary channel stacking (simplest) and switch to complex-structured layers only if you need phase equivariance or parameter reduction.

Key Takeaway

PyTorch handles complex autograd correctly via Wirtinger calculus. You only need to ensure your loss function is real-valued.

Why This Matters: Complex Networks for MIMO Detection

MIMO detection involves recovering complex transmitted symbols \mathbf{x} \in \mathbb{C}^{N_t} from received signals \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}. Complex-valued networks naturally handle the I/Q structure and can learn near-optimal detectors with far fewer parameters than real-split approaches.

Historical Note: Complex-Valued Neural Networks

1992-2020

Complex-valued neural networks were studied theoretically in the 1990s (Hirose, 1992; Georgiou & Koutsougeras, 1992) but gained practical traction only after deep learning frameworks added complex tensor support. PyTorch added native complex autograd in v1.7 (2020).

Historical Note: Wirtinger Calculus

1927-1983

Wilhelm Wirtinger introduced the \partial/\partial z and \partial/\partial z^* derivative operators in 1927. The application to optimisation of real-valued functions of complex variables was formalised by Brandwood (1983) and adopted by the signal processing community for adaptive filtering.

Wirtinger Derivative

Partial derivative with respect to z or z^*, enabling gradient-based optimisation in the complex domain.

Phase Equivariance

Property where a global phase rotation of the input produces the same rotation in the output.

CReLU

Complex ReLU that applies ReLU independently to real and imaginary parts.

modReLU

Complex activation: \text{modReLU}(z) = z \cdot \text{ReLU}(|z| - b) / |z|, preserving phase.

Differentiable Forward Model

A physics operator implemented in PyTorch so gradients can flow through it during training.

Complex Data Handling Approaches

Approach           | Implementation                            | Parameters | Phase Aware | Ease of Use
Real/Imag stack    | torch.cat([z.real, z.imag], dim=1)        | 4mn        | No          | Easiest
Complex structured | Constrained W with complex multiplication | 2mn        | Yes         | Medium
Native complex     | torch.cfloat tensors + autograd           | 2mn        | Yes         | Easy (PyTorch 1.7+)