Vision Transformer (ViT)
Definition: Vision Transformer (ViT)
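ViT applies the transformer to images as follows:
- Split the image into non-overlapping P × P patches
- Flatten each patch and linearly project it to a d_model-dimensional embedding
- Prepend a learnable [CLS] token and add learned positional embeddings
- Process the resulting sequence through a stack of transformer encoder blocks
- Classify from the encoder output at the [CLS] position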
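The tokenization steps above (patchify, flatten, project, prepend [CLS], add positions) can be sketched in NumPy for a single image. This is a minimal illustration, not the ViT reference code. The function name `tokenize` and the matrix `W_proj` are hypothetical. Random or zero arrays stand in for learned parameters.

```python
import numpy as np

def tokenize(img, patch_size, W_proj, cls_token, pos_embed):
    """Hypothetical helper: turn one image (C, H, W) into tokens (N + 1, d).

    W_proj:    (C * P * P, d) linear projection shared by all patches
    cls_token: (d,) stand-in for the learnable [CLS] embedding
    pos_embed: (N + 1, d) stand-in for the learned positional embeddings
    """
    C, H, W = img.shape
    P = patch_size
    assert H % P == 0 and W % P == 0, "image size must be divisible by P"
    # (C, H/P, P, W/P, P) -> (H/P, W/P, C, P, P): group pixels by patch
    patches = img.reshape(C, H // P, P, W // P, P).transpose(1, 3, 0, 2, 4)
    # Flatten each patch to length C*P*P; patches are in row-major order
    patches = patches.reshape(-1, C * P * P)              # (N, C*P*P)
    tokens = patches @ W_proj                              # (N, d)
    # Prepend [CLS], then add one positional vector per position
    tokens = np.concatenate([cls_token[None, :], tokens], axis=0)  # (N + 1, d)
    return tokens + pos_embed

# Usage: a 3x32x32 image with 16x16 patches gives N = 4 patches
rng = np.random.default_rng(0)
C, H, W, P, d = 3, 32, 32, 16, 8
img = rng.standard_normal((C, H, W))
W_proj = rng.standard_normal((C * P * P, d))
cls_token = np.zeros(d)
pos_embed = np.zeros(((H // P) * (W // P) + 1, d))
tokens = tokenize(img, P, W_proj, cls_token, pos_embed)
print(tokens.shape)  # (5, 8)
```

Each patch is flattened in (channel, row, column) order. That order matches the layout of a `Conv2d` weight of shape (d, C, P, P). For this reason, the class below can implement the same flatten-and-project step as a convolution with kernel size and stride both equal to P.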
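import torch
import torch.nn as nn

# TransformerBlock(d_model, n_heads, d_ff) is assumed to be defined elsewhere:
# an encoder block mapping (B, N, d_model) -> (B, N, d_model).
class ViT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, d_model=768,
                 n_heads=12, n_layers=12, n_classes=1000):
        super().__init__()
        n_patches = (img_size // patch_size) ** 2
        # Conv with kernel = stride = patch_size: splits the image into
        # non-overlapping patches and linearly projects each one to d_model
        self.patch_embed = nn.Conv2d(3, d_model, patch_size, stride=patch_size)
        # Learnable [CLS] token and learned positional embeddings (+1 for [CLS])
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.pos_embed = nn.Parameter(torch.randn(1, n_patches + 1, d_model))
        # Stack of encoder blocks, MLP hidden size 4 * d_model
        self.blocks = nn.Sequential(*[
            TransformerBlock(d_model, n_heads, d_model * 4)
            for _ in range(n_layers)])
        # Linear classifier applied to the [CLS] output
        self.head = nn.Linear(d_model, n_classes)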
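The short script below works through the constructor defaults by hand. It shows how `n_patches` and the embedding parameter shapes follow from `img_size`, `patch_size`, and `d_model`. The variable names mirror the constructor arguments. The counts assume the default `bias=True` on `nn.Conv2d` and `nn.Linear`.

```python
img_size, patch_size, d_model, n_classes = 224, 16, 768, 1000

grid = img_size // patch_size        # 14 patches per side
n_patches = grid ** 2                # 196 patches
seq_len = n_patches + 1              # 197 tokens including [CLS]

# Conv2d(3, d_model, patch_size, stride=patch_size): weight (d_model, 3, P, P) + bias
patch_embed_params = 3 * patch_size * patch_size * d_model + d_model
pos_embed_params = seq_len * d_model     # pos_embed: (1, 197, 768)
cls_params = d_model                     # cls_token: (1, 1, 768)
head_params = d_model * n_classes + n_classes

print(n_patches, seq_len)                # 196 197
print(patch_embed_params)                # 590592
print(pos_embed_params, head_params)     # 151296 769000
```

Because `img_size // patch_size` uses integer division, an image size that is not a multiple of the patch size would silently drop border pixels in the strided convolution.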
Example: ViT Forward Pass
Implement the ViT forward method.
Solution
Implementation
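def forward(self, img):
    # img: (B, 3, img_size, img_size); other sizes would not match pos_embed
    B = img.shape[0]
    # (B, d_model, H/P, W/P) -> (B, d_model, N) -> (B, N, d_model)
    x = self.patch_embed(img).flatten(2).transpose(1, 2)
    # Broadcast the single [CLS] token across the batch: (B, 1, d_model)
    cls = self.cls_token.expand(B, -1, -1)
    # Prepend [CLS] and add positional embeddings: (B, N + 1, d_model)
    x = torch.cat([cls, x], dim=1) + self.pos_embed
    x = self.blocks(x)
    # Classify from the [CLS] position only: (B, n_classes)
    return self.head(x[:, 0])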
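The ViT class calls `TransformerBlock(d_model, n_heads, d_model * 4)` but never defines it. The forward pass only requires that each block preserve the shape (B, N + 1, d_model). The sketch below is one possible block, written as an assumption rather than the source's definition. It uses the pre-norm layout of the original ViT paper: LayerNorm before attention and before the MLP, a residual connection after each, and a two-layer GELU MLP. It is built on PyTorch's `nn.MultiheadAttention` with `batch_first=True`. The `dropout` argument is a hypothetical addition.

```python
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    """Assumed pre-norm encoder block: (B, N, d_model) -> (B, N, d_model)."""

    def __init__(self, d_model, n_heads, d_ff, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        # batch_first=True so inputs are (B, N, d_model), as in ViT.forward
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout,
                                          batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        # Position-wise MLP: d_model -> d_ff -> d_model with GELU
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Self-attention sub-layer with residual connection
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        # MLP sub-layer with residual connection
        x = x + self.mlp(self.norm2(x))
        return x

# Usage: a block keeps the token-sequence shape unchanged
torch.manual_seed(0)
block = TransformerBlock(d_model=64, n_heads=4, d_ff=256)
x = torch.randn(2, 197, 64)   # (B, N + 1, d_model)
y = block(x)
print(y.shape)  # torch.Size([2, 197, 64])
```

The block preserves the shape, so a stack of such blocks can be chained with `nn.Sequential` as in the constructor. With pre-norm blocks, the original ViT also applies a final LayerNorm to the [CLS] output before the classification head. The class above omits that step.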