The U-Net Architecture: From Scratch
Definition: U-Net Architecture
U-Net is an encoder-decoder with skip connections at each resolution:
- Encoder: repeated Conv-BN-ReLU + MaxPool (downsample)
- Bottleneck: processing at lowest resolution
- Decoder: ConvTranspose2d (upsample) + concatenation with encoder features
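For the three-level network below, each decoder level $\ell \in \{3, 2, 1\}$ computes $d_\ell = \mathrm{DoubleConv}\big(\mathrm{Up}_\ell(d_{\ell+1}) \oplus e_\ell\big)$, with $d_4 = b$ the bottleneck output, where $\oplus$ denotes channel-wise concatenation and $e_\ell$ is the encoder feature map at the same resolution.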
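A minimal sketch of one decoder step, assuming base_ch = 64 and a 32×32 input so that the level-3 maps are 8×8 (the batch size and image size are illustrative). It checks that a transposed convolution with kernel 2 and stride 2 doubles H and W while halving the channels, and that concatenating the skip doubles the channels again:

```python
import torch
import torch.nn as nn

B, H, W = 2, 32, 32
b = torch.randn(B, 512, H // 8, W // 8)   # bottleneck output d_4
e3 = torch.randn(B, 256, H // 4, W // 4)  # encoder feature map at level 3

up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
u = up3(b)                                # (B, 256, H/4, W/4): spatial size doubled
skip_cat = torch.cat([u, e3], dim=1)      # (B, 512, H/4, W/4): channels stacked

print(tuple(u.shape), tuple(skip_cat.shape))
```

With zero padding, PyTorch's output-size rule reduces to H_out = (H_in - 1)·stride + kernel_size, so kernel 2 and stride 2 give exactly 2·H_in.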
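import torch
import torch.nn as nn

class DoubleConv(nn.Module):  # assumed helper: (Conv3x3 -> BN -> ReLU) x 2, padding=1 keeps H, W
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))

    def forward(self, x):
        return self.block(x)

class UNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        # Encoder: channel width doubles at each level
        self.enc1 = DoubleConv(in_ch, base_ch)
        self.enc2 = DoubleConv(base_ch, base_ch*2)
        self.enc3 = DoubleConv(base_ch*2, base_ch*4)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(base_ch*4, base_ch*8)
        # Decoder: each up-conv halves the channels; concatenating the skip doubles them again
        self.up3 = nn.ConvTranspose2d(base_ch*8, base_ch*4, 2, stride=2)
        self.dec3 = DoubleConv(base_ch*8, base_ch*4)
        self.up2 = nn.ConvTranspose2d(base_ch*4, base_ch*2, 2, stride=2)
        self.dec2 = DoubleConv(base_ch*4, base_ch*2)
        self.up1 = nn.ConvTranspose2d(base_ch*2, base_ch, 2, stride=2)
        self.dec1 = DoubleConv(base_ch*2, base_ch)
        self.out_conv = nn.Conv2d(base_ch, out_ch, 1)  # 1x1 conv maps to out_ch channels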
The skip connections preserve high-resolution spatial details that are lost during downsampling.
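To see how base_ch drives model size, the sketch below counts the parameters of one DoubleConv stage two ways: analytically and by summing PyTorch's parameter tensors. It assumes the Conv-BN-ReLU layout with 3×3 kernels and a bias on each convolution (PyTorch's default); the helper names `double_conv_params` and `torch_param_count` are hypothetical. Each 3×3 convolution contributes 9·c_in·c_out weights plus c_out biases, and each BatchNorm2d adds 2·c_out learnable parameters, so doubling every channel width roughly quadruples the count.

```python
import torch.nn as nn

def double_conv_params(c_in, c_out):
    """Analytic parameter count for (Conv3x3 -> BN -> ReLU) x 2 with conv biases."""
    conv1 = 9 * c_in * c_out + c_out
    conv2 = 9 * c_out * c_out + c_out
    bn = 2 * (2 * c_out)  # weight and bias of each of the two BatchNorm2d layers
    return conv1 + conv2 + bn

def torch_param_count(c_in, c_out):
    """Same block built in PyTorch; running statistics are buffers, not parameters."""
    block = nn.Sequential(
        nn.Conv2d(c_in, c_out, 3, padding=1), nn.BatchNorm2d(c_out), nn.ReLU(),
        nn.Conv2d(c_out, c_out, 3, padding=1), nn.BatchNorm2d(c_out), nn.ReLU())
    return sum(p.numel() for p in block.parameters())

for base in (32, 64):
    # bottleneck stage: base*4 -> base*8 channels
    print(base, double_conv_params(base * 4, base * 8), torch_param_count(base * 4, base * 8))
```

Under these assumptions the bottleneck stage alone holds about 3.5 million parameters for base_ch = 64, versus about 0.89 million for base_ch = 32.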
Example: U-Net Forward Pass
Implement the forward method showing the encoder-decoder data flow.
Forward method
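def forward(self, x):  # method of UNet; shape comments assume base_ch=64
    # Encoder (H and W must be divisible by 8 for the skip shapes to match)
    e1 = self.enc1(x)                    # (B, 64, H, W)
    e2 = self.enc2(self.pool(e1))        # (B, 128, H/2, W/2)
    e3 = self.enc3(self.pool(e2))        # (B, 256, H/4, W/4)
    # Bottleneck
    b = self.bottleneck(self.pool(e3))   # (B, 512, H/8, W/8)
    # Decoder: upsample, concatenate the same-resolution encoder map, then DoubleConv
    d3 = self.dec3(torch.cat([self.up3(b), e3], dim=1))   # (B, 256, H/4, W/4)
    d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))  # (B, 128, H/2, W/2)
    d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))  # (B, 64, H, W)
    return self.out_conv(d1)             # (B, out_ch, H, W)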
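Because the network pools three times, the skip concatenations only line up when H and W are divisible by 2^3 = 8. A 100-pixel side, for instance, shrinks to 50, 25 and then 12, and upsampling 12 gives 24, which cannot be concatenated with the 25-pixel e3. One common workaround, sketched here with hypothetical helpers `pad_to_multiple` and `crop_to`, is to pad the input before the forward pass and crop the output afterwards. Reflect padding is an assumed choice; zero or replicate padding would also work.

```python
import torch
import torch.nn.functional as F

def pad_to_multiple(x, multiple=8, mode="reflect"):
    """Pad the last two dims of a (B, C, H, W) tensor up to a multiple of `multiple`."""
    h, w = x.shape[-2:]
    pad_h = (-h) % multiple
    pad_w = (-w) % multiple
    # F.pad takes (left, right, top, bottom) for the last two dimensions
    return F.pad(x, (0, pad_w, 0, pad_h), mode=mode), (h, w)

def crop_to(y, size):
    """Undo pad_to_multiple by cropping back to the original (H, W)."""
    h, w = size
    return y[..., :h, :w]

x = torch.randn(1, 1, 100, 75)
x_pad, orig_size = pad_to_multiple(x)
print(tuple(x_pad.shape))       # (1, 1, 104, 80)
y = crop_to(x_pad, orig_size)   # in practice: crop_to(model(x_pad), orig_size)
print(torch.equal(y, x))        # True
```

Padding only on the right and bottom keeps the crop a simple slice; centred padding would need matching offsets when cropping.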
Example: U-Net for Image Denoising
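Train a U-Net to denoise images, where the input is the noisy image $\tilde{x} = x + \sigma\varepsilon$ with $\varepsilon \sim \mathcal{N}(0, I)$ and $\sigma = 0.1$, and the target is the clean image $x$.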
Training setup
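import torch
import torch.nn as nn

# Assumes UNet is defined as above and train_loader yields batches of clean
# images shaped (B, 1, H, W), with H and W divisible by 8.
model = UNet(in_ch=1, out_ch=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
sigma = 0.1  # noise standard deviation

model.train()
for epoch in range(50):
    for clean in train_loader:
        noisy = clean + sigma * torch.randn_like(clean)  # fresh noise every batch
        optimizer.zero_grad()
        denoised = model(noisy)
        loss = loss_fn(denoised, clean)  # compare against the clean target
        loss.backward()
        optimizer.step()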
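To judge the trained denoiser, peak signal-to-noise ratio is a common metric: $\mathrm{PSNR} = 10\log_{10}(\mathrm{MAX}^2/\mathrm{MSE})$. The sketch assumes pixel values in [0, 1], so MAX = 1; the helper name `psnr` is hypothetical. With σ = 0.1 Gaussian noise the expected MSE of the noisy input is σ² = 0.01, so the noisy input itself sits near 20 dB and a useful denoiser should score clearly above that.

```python
import torch

def psnr(pred, target, max_val=1.0):
    """Peak signal-to-noise ratio in dB, computed over the whole batch."""
    mse = torch.mean((pred - target) ** 2)
    return 10 * torch.log10(max_val ** 2 / mse)

torch.manual_seed(0)
clean = torch.rand(8, 1, 64, 64)               # illustrative images in [0, 1]
noisy = clean + 0.1 * torch.randn_like(clean)  # same noise model as the training loop
print(f"noisy-input PSNR: {psnr(noisy, clean).item():.2f} dB")  # close to 20 dB
```

In an evaluation loop one would compare psnr(model(noisy), clean) against this baseline, ideally under torch.no_grad() and model.eval().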
Why This Matters: U-Net for Range-Doppler Processing
The 2D range-Doppler map from radar processing is structurally similar to an image. U-Net architectures have been applied to denoise and enhance range-Doppler maps, leveraging multi-scale features to separate targets from clutter.
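As a minimal sketch of how a range-Doppler map could be turned into U-Net input, assume a hypothetical matrix of complex radar samples arranged as (slow time, fast time), i.e. (pulses or chirps, samples per pulse). An FFT along fast time gives range bins, an FFT along slow time (shifted so zero Doppler is centred) gives Doppler bins, and the log-magnitude map is standardised and reshaped to (1, 1, H, W). The helper name, random data, sizes and normalisation are illustrative assumptions, not a prescribed pipeline.

```python
import numpy as np
import torch

def range_doppler_image(iq):
    """iq: complex array (n_pulses, n_samples) = (slow time, fast time)."""
    range_fft = np.fft.fft(iq, axis=1)                            # fast time -> range bins
    rd = np.fft.fftshift(np.fft.fft(range_fft, axis=0), axes=0)   # slow time -> Doppler, zero centred
    mag_db = 20 * np.log10(np.abs(rd) + 1e-12)                    # log magnitude; epsilon avoids log(0)
    mag_db = (mag_db - mag_db.mean()) / (mag_db.std() + 1e-12)    # per-map standardisation (assumed)
    # (Doppler, range) -> (B=1, C=1, H, W) float32 tensor for a single-channel U-Net
    return torch.from_numpy(mag_db.astype(np.float32))[None, None]

rng_gen = np.random.default_rng(0)
iq = rng_gen.standard_normal((64, 128)) + 1j * rng_gen.standard_normal((64, 128))
x = range_doppler_image(iq)
print(tuple(x.shape))  # (1, 1, 64, 128): both sides divisible by 8
```

Choosing power-of-two pulse and sample counts keeps both axes divisible by 8, so the map can go straight into UNet(in_ch=1, out_ch=1) without padding.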
Skip Connection
A direct path that bypasses one or more layers, either by addition (ResNet) or concatenation (U-Net).
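The two flavours differ in what they do to the channel count. In this small sketch an arbitrary 3×3 convolution stands in for the bypassed layers: addition requires matching shapes and keeps the width, while concatenation stacks the feature maps, so the next layer must accept the combined width (which is why dec3 in the U-Net takes base_ch*8 input channels).

```python
import torch
import torch.nn as nn

x = torch.randn(2, 64, 16, 16)
f = nn.Conv2d(64, 64, 3, padding=1)  # stand-in for the bypassed layers

res = x + f(x)                       # ResNet-style: shapes must match, channels stay 64
cat = torch.cat([f(x), x], dim=1)    # U-Net-style: channels add up to 128
print(tuple(res.shape), tuple(cat.shape))
```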
Encoder-Decoder
Architecture that compresses input to a low-dimensional representation (encoder) then expands it back (decoder).
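In the U-Net defined earlier, the compression is mainly spatial: the channel count grows as the maps shrink. A quick count of activation values per stage, assuming in_ch = 1, base_ch = 64 and a 256×256 input (illustrative values), makes this concrete:

```python
in_ch, base_ch, H, W = 1, 64, 256, 256

# (name, channels, downsampling factor) for each stage of the three-level U-Net
stages = [("input", in_ch, 1), ("e1", base_ch, 1), ("e2", base_ch * 2, 2),
          ("e3", base_ch * 4, 4), ("bottleneck", base_ch * 8, 8)]

sizes = {name: c * (H // f) * (W // f) for name, c, f in stages}
for name, n in sizes.items():
    print(f"{name:>10}: {n:,} values")
print("bottleneck / input:", sizes["bottleneck"] / sizes["input"])
```

Each level halves the activation count (halving H and W quarters the area while the channels double), yet the bottleneck still holds 8× as many values as the single-channel input. It is low-resolution rather than low-dimensional here, and the skip connections let fine detail bypass it entirely.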