Vision Transformer (ViT)

Definition: Vision Transformer (ViT)

ViT applies the transformer encoder to images in five steps:

  1. Split the image into non-overlapping $P \times P$ patches.
  2. Flatten each patch and linearly project it to dimension $d_{\text{model}}$.
  3. Prepend a learnable [CLS] token and add positional embeddings.
  4. Process the token sequence through the transformer encoder blocks.
  5. Classify using the encoder output at the [CLS] position.
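
The arithmetic behind steps 1 to 3 can be checked with a short sketch. The helper name `vit_token_shapes` is hypothetical, and the sketch assumes a square RGB image whose side is divisible by the patch size.

```python
def vit_token_shapes(img_size=224, patch_size=16, channels=3):
    """Return (n_patches, seq_len, patch_dim) for a square image.

    Hypothetical helper for illustration only.
    """
    assert img_size % patch_size == 0, "image side must be divisible by patch size"
    grid = img_size // patch_size                     # patches per side
    n_patches = grid * grid                           # step 1: number of patches
    patch_dim = channels * patch_size * patch_size    # step 2: flattened patch length
    seq_len = n_patches + 1                           # step 3: +1 for the [CLS] token
    return n_patches, seq_len, patch_dim


print(vit_token_shapes())  # (196, 197, 768)
```

With the defaults, a 224×224 image yields a 14×14 grid of patches, so the encoder sees 197 tokens. The flattened patch length (768) happens to equal the default $d_{\text{model}}$, but the two are independent in general.
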
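
import torch
import torch.nn as nn

# TransformerBlock(d_model, n_heads, d_ff) is assumed to be defined elsewhere.
class ViT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, d_model=768,
                 n_heads=12, n_layers=12, n_classes=1000):
        super().__init__()
        n_patches = (img_size // patch_size) ** 2  # 14 * 14 = 196 by default
        # Conv with kernel == stride == patch_size: splits into patches and projects
        self.patch_embed = nn.Conv2d(3, d_model, patch_size, stride=patch_size)
        # Learnable [CLS] token and positional embeddings (one per token, incl. [CLS])
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.pos_embed = nn.Parameter(torch.randn(1, n_patches + 1, d_model))
        # Encoder stack; d_model * 4 is presumably the feed-forward width
        self.blocks = nn.Sequential(*[
            TransformerBlock(d_model, n_heads, d_model * 4)
            for _ in range(n_layers)])
        # Classification head, applied to the [CLS] output
        self.head = nn.Linear(d_model, n_classes)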
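
The `nn.Conv2d` patch embedding in the class above performs steps 1 and 2 in a single operation. The sketch below checks this numerically: it compares the convolution against an explicit "extract patches, flatten, apply a linear map" path that reuses the same weights. The small sizes (16×16 image, patch size 4, `d_model` 8) are chosen only to keep the example quick.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
patch_size, d_model = 4, 8
x = torch.randn(2, 3, 16, 16)  # (batch, channels, height, width)

# Path 1: strided convolution, as in ViT.patch_embed
conv = nn.Conv2d(3, d_model, patch_size, stride=patch_size)
tokens_conv = conv(x).flatten(2).transpose(1, 2)    # (2, 16, 8)

# Path 2: explicit patch extraction followed by a linear projection
patches = F.unfold(x, kernel_size=patch_size, stride=patch_size)  # (2, 48, 16)
patches = patches.transpose(1, 2)                   # (2, 16, 48): one row per patch
weight = conv.weight.reshape(d_model, -1)           # (8, 48): conv kernel as a matrix
tokens_lin = patches @ weight.T + conv.bias         # (2, 16, 8)

# Both paths produce the same patch tokens up to floating-point error
assert torch.allclose(tokens_conv, tokens_lin, atol=1e-5)
```
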

Example: ViT Forward Pass

Implement the ViT forward method.
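
One possible solution is sketched below. It is not a reference implementation. To keep it self-contained, the class is named `ViTSketch` (a hypothetical name) and uses PyTorch's built-in `nn.TransformerEncoderLayer` as a stand-in for `TransformerBlock`, whose definition is not shown here; the `forward` method itself depends only on the attributes defined in `__init__` above.

```python
import torch
import torch.nn as nn


class ViTSketch(nn.Module):
    def __init__(self, img_size=224, patch_size=16, d_model=768,
                 n_heads=12, n_layers=12, n_classes=1000):
        super().__init__()
        n_patches = (img_size // patch_size) ** 2
        self.patch_embed = nn.Conv2d(3, d_model, patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.pos_embed = nn.Parameter(torch.randn(1, n_patches + 1, d_model))
        # Assumption: nn.TransformerEncoderLayer stands in for TransformerBlock
        self.blocks = nn.Sequential(*[
            nn.TransformerEncoderLayer(d_model, n_heads, d_model * 4,
                                       batch_first=True, norm_first=True)
            for _ in range(n_layers)])
        self.head = nn.Linear(d_model, n_classes)

    def forward(self, x):
        batch = x.shape[0]
        x = self.patch_embed(x)                 # (B, d_model, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)        # (B, n_patches, d_model)
        cls = self.cls_token.expand(batch, -1, -1)
        x = torch.cat([cls, x], dim=1)          # (B, n_patches + 1, d_model)
        x = x + self.pos_embed                  # broadcast over the batch dimension
        x = self.blocks(x)                      # encoder keeps the shape unchanged
        return self.head(x[:, 0])               # logits from the [CLS] position


# Usage with small sizes so the example runs quickly
model = ViTSketch(img_size=32, patch_size=8, d_model=64,
                  n_heads=4, n_layers=2, n_classes=10)
logits = model(torch.randn(2, 3, 32, 32))
print(logits.shape)  # torch.Size([2, 10])
```

The `expand` call shares the single learned [CLS] vector across the batch without copying it, and indexing `x[:, 0]` selects the [CLS] output because that token was concatenated first.
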