Complex Tensors in PyTorch
Definition: Complex Tensor Types
PyTorch supports two complex dtypes:
| dtype | Real part | Imaginary part | Total bytes |
|---|---|---|---|
| complex64 | float32 | float32 | 8 |
| complex128 | float64 | float64 | 16 |
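As a quick sanity check (a minimal sketch using element_size()), the per-element storage matches the table:

```python
import torch

# complex64 packs two float32 values; complex128 packs two float64 values.
print(torch.zeros(1, dtype=torch.complex64).element_size())    # 8
print(torch.zeros(1, dtype=torch.complex128).element_size())   # 16
```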
Creation:
```python
z = torch.tensor([1+2j, 3+4j])                      # complex64 with default settings
z = torch.complex(torch.randn(3), torch.randn(3))   # complex64
z = torch.randn(3, dtype=torch.complex128)          # random complex128
```
Access real and imaginary parts via .real and .imag (these are views,
not copies).
Complex tensors store real and imaginary parts interleaved in memory, matching NumPy's convention and enabling zero-copy conversion.
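A short sketch of what this layout means in practice (assuming a CPU tensor that does not require grad):

```python
import torch

z = torch.tensor([1+2j, 3+4j], dtype=torch.complex128)

# view_as_real exposes the interleaved layout: a trailing dimension of size 2
# holding (real, imag) pairs that share storage with z.
print(torch.view_as_real(z))   # tensor([[1., 2.], [3., 4.]], dtype=torch.float64)

# Conversion to NumPy reuses the same buffer, so it is zero-copy:
a = z.numpy()
a[0] = 9 + 9j
print(z[0])                    # tensor(9.+9.j, dtype=torch.complex128)
```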
Definition: Wirtinger Derivatives
For a function $f$ of a complex variable $z = x + iy$, the Wirtinger derivatives are:

$$\frac{\partial f}{\partial z} = \frac{1}{2}\left(\frac{\partial f}{\partial x} - i\,\frac{\partial f}{\partial y}\right), \qquad \frac{\partial f}{\partial z^*} = \frac{1}{2}\left(\frac{\partial f}{\partial x} + i\,\frac{\partial f}{\partial y}\right),$$

where $z^* = x - iy$. PyTorch's autograd returns the conjugate Wirtinger derivative $\partial L/\partial z^*$ because its negative is the steepest-descent direction for real-valued loss functions of complex parameters.
This convention means z.grad gives the direction to subtract
for gradient descent, consistent with the real-valued case.
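As an illustration (my own sketch using a standard torch.optim.SGD loop, not code from this chapter), a complex parameter can be optimized exactly like a real one:

```python
import torch

# Fit a complex parameter z to a fixed complex target by minimizing a
# real-valued loss; the optimizer update subtracts lr * z.grad as usual.
target = torch.tensor([1+1j, -1+0j, 0+2j], dtype=torch.complex128)
z = torch.zeros(3, dtype=torch.complex128, requires_grad=True)
opt = torch.optim.SGD([z], lr=0.1)

for _ in range(200):
    opt.zero_grad()
    loss = (z - target).abs().pow(2).sum()   # real-valued loss
    loss.backward()
    opt.step()

print(torch.allclose(z.detach(), target, atol=1e-4))   # True: z converged to target
```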
Theorem: Conjugate Wirtinger Derivative as Gradient
For a real-valued function $L(z)$ of a complex parameter $z = x + iy$, the direction of steepest descent is:

$$-\frac{\partial L}{\partial z^*} = -\frac{1}{2}\left(\frac{\partial L}{\partial x} + i\,\frac{\partial L}{\partial y}\right).$$

The conjugate Wirtinger derivative $\partial L/\partial z^*$ is exactly what PyTorch stores in z.grad, so the update rule $z \leftarrow z - \alpha \cdot \texttt{z.grad}$ minimizes $L$, just as in the real case.
Think of a complex parameter $z = x + iy$ as two real parameters $x$ and $y$. Wirtinger calculus packages the $\partial L/\partial x$ and $\partial L/\partial y$ gradients into a single complex number whose negative points in the steepest-descent direction.
Example: Gradient of |z|^2 via Autograd
Verify that autograd correctly computes $\partial |z|^2 / \partial z^* = z$.
Implementation
```python
import torch

z = torch.tensor([1+2j, 3-1j], dtype=torch.complex128, requires_grad=True)
L = (z * z.conj()).real.sum()   # |z|^2 = z * z*
L.backward()

print(f"z.grad = {z.grad}")     # should equal z
print(f"z      = {z}")
print(f"Match: {torch.allclose(z.grad, z)}")
```
Explanation
Since $L = \sum_k z_k z_k^*$, the conjugate Wirtinger derivative is $\partial L/\partial z_k^* = z_k$. PyTorch confirms this with z.grad == z.
Example: FFT and Spectral Analysis with PyTorch
Compute the FFT of the signal $x(t) = \cos(2\pi \cdot 50\,t) + 0.5\,\sin(2\pi \cdot 120\,t)$, sampled at 1000 Hz, using torch.fft, and find the dominant frequencies.
Implementation
```python
import torch

fs = 1000.0
t = torch.arange(0, 1.0, 1.0/fs, dtype=torch.float64)
x = torch.cos(2 * torch.pi * 50 * t) + 0.5 * torch.sin(2 * torch.pi * 120 * t)

X = torch.fft.fft(x)
freqs = torch.fft.fftfreq(len(x), 1.0/fs)

# Power spectrum (one-sided)
mask = freqs >= 0
power = (X[mask].abs() ** 2) / len(x)

# Find peaks
top_k = torch.topk(power, 3)
print("Dominant frequencies:")
for idx in top_k.indices:
    print(f"  {freqs[mask][idx]:.0f} Hz, power={power[idx]:.2f}")
```
Key Points
- torch.fft.fft returns complex128 when the input is float64.
- torch.fft.fftfreq mirrors numpy.fft.fftfreq.
- The entire pipeline supports autograd, so you can differentiate through FFT operations (see the sketch below).
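A minimal sketch of that last point (illustrative; it uses torch.fft.rfft, the FFT for real-valued input):

```python
import torch

# Differentiate a spectral quantity with respect to the time-domain signal.
x = torch.randn(128, dtype=torch.float64, requires_grad=True)
X = torch.fft.rfft(x)                # one-sided spectrum of a real signal
loss = X.abs().pow(2).sum()          # real-valued spectral energy
loss.backward()
print(x.grad.shape, x.grad.dtype)    # torch.Size([128]) torch.float64
```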
Interactive FFT with Complex Tensors
Adjust signal frequencies, amplitudes, and sampling rate to see the FFT spectrum update in real time. Observe aliasing when the sampling rate is too low.
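As a static stand-in for the interactive demo, a small sketch of aliasing: a 700 Hz tone sampled at 1000 Hz lies above the 500 Hz Nyquist limit and shows up at 300 Hz.

```python
import torch

# A 700 Hz tone sampled at 1000 Hz aliases: the spectrum peaks at |700 - 1000| = 300 Hz.
fs = 1000.0
t = torch.arange(0, 1.0, 1.0/fs, dtype=torch.float64)
x = torch.cos(2 * torch.pi * 700 * t)

X = torch.fft.rfft(x)
freqs = torch.fft.rfftfreq(len(x), 1.0/fs)
print(f"Apparent frequency: {freqs[X.abs().argmax()]:.0f} Hz")   # 300 Hz, not 700 Hz
```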
Quick Check
When PyTorch computes z.grad for a real-valued loss $L$ with complex parameter $z$, what does it return?
The conjugate Wirtinger derivative $\partial L/\partial z^*$, whose negative is the steepest-descent direction for real-valued losses.
Common Mistake: Modifying .real or .imag Without Caution
Mistake:
Because .real and .imag are views, assigning to them writes into the original tensor in place. On a leaf tensor that requires grad this raises an error immediately, and on intermediate results it can break the computation graph:

```python
z = torch.tensor([1+2j], requires_grad=True)
z.real[0] = 5.0   # RuntimeError! (in-place write to a view of a leaf that requires grad)
```
Correction:
Build complex tensors from separate real and imaginary parts using
torch.complex(real_part, imag_part) to keep the graph intact.
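A minimal sketch of that correction (illustrative):

```python
import torch

# Keep real and imaginary parts as separate real-valued leaves and combine
# them out of place; autograd then tracks gradients for both parts.
re = torch.tensor([1.0], requires_grad=True)
im = torch.tensor([2.0], requires_grad=True)
z = torch.complex(re, im)            # differentiable construction, no in-place write

loss = (z * z.conj()).real.sum()
loss.backward()
print(re.grad, im.grad)              # gradients flow back to the real leaves
```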
Historical Note: Wilhelm Wirtinger and Complex Differentiation
1920s: Wilhelm Wirtinger (1865–1945) introduced the calculus of $\partial/\partial z$ and $\partial/\partial z^*$ in 1927. For decades it was a niche tool in several complex variables. Its revival in engineering came through adaptive filter theory (Brandwood 1983) and, more recently, through deep learning with complex-valued neural networks. PyTorch's adoption of Wirtinger calculus in 2020 (v1.7) made complex autograd practical.
Wirtinger Derivative
A pair of differential operators, $\partial/\partial z$ and $\partial/\partial z^*$, that decompose the gradient of a function of complex variables into holomorphic and anti-holomorphic parts.
Hermitian Transpose
The conjugate transpose of a matrix, computed as
A.conj().T or A.mH in PyTorch.
Related: Wirtinger Derivative
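A one-line check of the equivalence (illustrative):

```python
import torch

A = torch.randn(2, 3, dtype=torch.complex64)
# .mH is the matrix Hermitian (conjugate) transpose.
print(torch.allclose(A.mH, A.conj().transpose(-2, -1)))   # True
```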
Key Takeaway
PyTorch supports complex tensors natively with complex64 and complex128. Autograd uses Wirtinger calculus, returning the conjugate Wirtinger derivative $\partial L/\partial z^*$, whose negative is the steepest-descent direction for real-valued losses. FFT, the Hermitian transpose (.mH), and most standard operations work seamlessly with complex tensors.
Complex Tensor Operations
# Code from: ch12/python/complex_tensors.py