VisionTransformer is a small PyTorch project for image classification with a Vision Transformer (ViT). The repository began as a from-scratch MNIST tutorial and now includes a configurable training script, dataset presets for MNIST, FashionMNIST, and CIFAR-10, and a modular implementation split across models/ and data/.
The current source code is the authoritative reference for how this repo works today. The linked Colab notebook and article are still useful background material, but they describe an earlier, MNIST-only tutorial version of the project.
- A ViT implementation built from patch embeddings, positional encodings, transformer encoder blocks, and a classification head.
- Dataset-specific defaults in data/configs.py for MNIST, FashionMNIST, and CIFAR-10.
- A CLI training entrypoint in training.py with overrides for model width, patch size, augmentation, optimization, checkpoint path, and more.
- Optional train/validation splits, cosine learning rate scheduling, warmup, and AdamW weight decay.
- A Colab notebook in notebooks/VisionTransformer.ipynb that mirrors the original educational walkthrough.
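Warmup followed by cosine decay is a common way to schedule the learning rate. The sketch below is a minimal illustration that assumes a linear warmup over warmup_steps optimizer steps, followed by a cosine decay to min_lr over the remaining steps. The function name lr_at_step and its parameters are hypothetical. The repo's actual schedule (per-step or per-epoch updates, warmup shape, final learning rate) may differ, so check training.py for the real behavior.

```python
import math


def lr_at_step(step, base_lr, warmup_steps, total_steps, min_lr=0.0):
    """Hypothetical schedule: linear warmup, then cosine decay to min_lr.

    This is an assumed shape for illustration, not the formula used in training.py.
    """
    if warmup_steps > 0 and step < warmup_steps:
        # Linear ramp: base_lr / warmup_steps at step 0, base_lr at step warmup_steps - 1.
        return base_lr * (step + 1) / warmup_steps
    # Fraction of the post-warmup phase completed, clamped to [0, 1].
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    # Cosine factor falls from 1 (progress = 0) to 0 (progress = 1).
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (base_lr - min_lr) * cosine


# Example: 5 warmup steps out of 50 total, peak learning rate 1e-3.
schedule = [lr_at_step(s, base_lr=1e-3, warmup_steps=5, total_steps=50) for s in range(51)]
print(schedule[:6])
```

To apply a shape like this to a PyTorch optimizer such as torch.optim.AdamW, one option is torch.optim.lr_scheduler.LambdaLR with a lambda that returns lr_at_step(step, ...) / base_lr, because LambdaLR multiplies the optimizer's initial learning rate by the returned factor.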
- training.py: CLI entrypoint for training, checkpointing, and test-set evaluation.
- models/model.py: Patch embedding, positional encoding, multi-head attention, encoder block, and VisionTransformer.
- data/configs.py: Default hyperparameters for each supported dataset.
- data/data_utils.py: Config selection, dataset loading, normalization, splitting, and dataloaders.
- data/datasets.py: Thin dataset wrapper used after random train/validation splits.
- notebooks/VisionTransformer.ipynb: Tutorial notebook for the original MNIST-only implementation.
- docs/architecture.md: How the current code flows end to end.
- docs/training-and-configuration.md: Setup, CLI reference, defaults, and example commands.
- docs/tutorial-drift.md: Differences between the tutorial materials and the current repo.
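Before the encoder runs, a ViT turns each image into a sequence of patch tokens with position information attached. The arithmetic below is a sketch that assumes square images, non-overlapping square patches, and a single prepended class token. It uses the fixed sine/cosine positional encoding from the original Transformer paper. Whether models/model.py uses exactly this encoding (rather than a learned one) and a class token is an assumption to verify in the source. The helper names num_patches, patch_dim, and sinusoidal_positions are hypothetical.

```python
import math


def num_patches(img_size, patch_size):
    """Number of non-overlapping square patches in a square image."""
    if img_size % patch_size != 0:
        raise ValueError(f"patch_size {patch_size} must divide img_size {img_size}")
    per_side = img_size // patch_size
    return per_side * per_side


def patch_dim(patch_size, channels):
    """Number of raw pixel values in one flattened patch."""
    return patch_size * patch_size * channels


def sinusoidal_positions(seq_len, d_model):
    """Fixed sine/cosine encodings as a (seq_len x d_model) nested list."""
    table = []
    for pos in range(seq_len):
        row = []
        for i in range(d_model):
            # Dimensions 2k and 2k+1 share the frequency 1 / 10000^(2k / d_model).
            angle = pos / (10000 ** ((i - i % 2) / d_model))
            row.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


# Hypothetical settings: a 28x28 grayscale MNIST image with 4x4 patches.
patches = num_patches(28, 4)  # 7 * 7 patches
seq_len = patches + 1         # assumed extra class token for the classification head
pe = sinusoidal_positions(seq_len, d_model=16)
print(patches, seq_len, patch_dim(4, 1), len(pe), len(pe[0]))
```

With the same 4x4 patches, a 32x32 RGB CIFAR-10 image would give 64 patches of 48 values each, so the patch size directly controls both sequence length and per-token input width.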
Run commands from the repository root.
python3 -m venv .venv
source .venv/bin/activate
python -m pip install --upgrade pip
The pinned requirements.txt expects CUDA 12.1 PyTorch wheels:
pip install --extra-index-url https://download.pytorch.org/whl/cu121 -r requirements.txt
If you are on CPU-only hardware or use a different CUDA version, install matching torch and torchvision wheels first, then install the remaining packages. For a CPU-only setup:
pip install --index-url https://download.pytorch.org/whl/cpu torch==2.2.0 torchvision==0.17.0
pip install datasets==2.17.1 numpy==2.0.0
MNIST with default settings:
python3 training.py --dataset mnist
FashionMNIST with a validation split and validation accuracy reporting:
python3 training.py --dataset fashion_mnist --train_val_split 55000 5000 --get_val_accuracy True
CIFAR-10 with an explicit checkpoint path:
python3 training.py --dataset cifar10 --model_location checkpoints/cifar10_vit.pt
Training downloads the training dataset if needed, saves the best checkpoint seen during training, and then evaluates that checkpoint on the test set.
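The "save the best checkpoint, then evaluate it" flow follows a common pattern, sketched below in plain Python. The three callbacks are hypothetical stand-ins rather than functions from training.py: train_one_epoch would return something like a model's state_dict, evaluate would return a score where higher is better (for example validation accuracy), and save would write the checkpoint, for example with torch.save to the --model_location path. Which metric training.py actually uses to pick the best checkpoint is an assumption here.

```python
import copy


def train_with_best_checkpoint(num_epochs, train_one_epoch, evaluate, save):
    """Keep and save the state with the highest score seen so far.

    train_one_epoch(epoch) -> state, evaluate(state) -> float, save(state) -> None
    are hypothetical callbacks supplied by the caller.
    """
    best_score, best_state = float("-inf"), None
    for epoch in range(num_epochs):
        state = train_one_epoch(epoch)
        score = evaluate(state)
        if score > best_score:
            best_score = score
            # Copy so later training cannot mutate the snapshot being kept.
            best_state = copy.deepcopy(state)
            save(best_state)
    return best_score, best_state


# Toy usage: fake per-epoch scores stand in for real validation accuracy.
scores = [0.90, 0.95, 0.93]
saved = []
best, state = train_with_best_checkpoint(
    num_epochs=3,
    train_one_epoch=lambda e: {"epoch": e},
    evaluate=lambda s: scores[s["epoch"]],
    save=saved.append,
)
print(best, state, len(saved))
```

After the loop, the test-set evaluation step would reload the saved state (for example with torch.load and load_state_dict) so the reported test accuracy comes from the best checkpoint rather than the last epoch.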
- Device selection is automatic: CUDA if available, otherwise CPU.
- Dataset defaults come from data/configs.py and are overridden by CLI arguments.
- Training data is downloaded into ../datasets relative to the current working directory.
- The default checkpoint path is model.pt in the current working directory.
- Validation is only used if train_val_split leaves examples for a validation set.
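One way to layer CLI overrides on top of per-dataset defaults is to give every override flag default=None, so "not passed" can be told apart from "passed". The sketch below uses only argparse. The DEFAULTS values, the --patch_size flag name, and build_config are hypothetical. The real keys and values live in data/configs.py, and training.py may merge them differently. The --dataset, --train_val_split, and --model_location flags and the model.pt default come from the commands and notes above.

```python
import argparse

# Hypothetical per-dataset defaults; the real ones are defined in data/configs.py.
DEFAULTS = {
    "mnist": {"patch_size": 4, "train_val_split": None, "model_location": "model.pt"},
    "fashion_mnist": {"patch_size": 4, "train_val_split": None, "model_location": "model.pt"},
    "cifar10": {"patch_size": 4, "train_val_split": None, "model_location": "model.pt"},
}


def build_config(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", required=True, choices=sorted(DEFAULTS))
    # default=None means "not given", so the dataset default is kept.
    parser.add_argument("--patch_size", type=int, default=None)  # hypothetical flag name
    parser.add_argument("--train_val_split", type=int, nargs=2, default=None)
    parser.add_argument("--model_location", default=None)
    args = parser.parse_args(argv)

    config = dict(DEFAULTS[args.dataset])
    for key, value in vars(args).items():
        if key != "dataset" and value is not None:
            config[key] = value  # an explicit CLI value overrides the dataset default
    config["dataset"] = args.dataset

    # Validation only runs if the split reserves at least one example for it.
    split = config["train_val_split"]
    config["use_validation"] = split is not None and split[1] > 0
    return config


print(build_config(["--dataset", "fashion_mnist", "--train_val_split", "55000", "5000"]))
```

A split such as 50000 0 on CIFAR-10 parses fine but leaves use_validation False, which matches the note that validation only happens when the split leaves examples for it.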
- Architecture walkthrough
- Training and configuration reference
- Tutorial drift and implementation notes
- Colab notebook: the original step-by-step MNIST tutorial.
- Medium article: the April 4, 2024 tutorial narrative by Matt Nguyen.
- Correll Lab mirror: an accessible mirror of the article.
Use those resources to understand the original teaching flow. Use this repository's Python modules and the docs in docs/ to understand the current implementation.
- The tutorial materials and notebook predate the current training pipeline and do not document later additions such as validation splits, CIFAR-10 support, warmup, cosine scheduling, or AdamW.
- requirements.txt pins CUDA-specific PyTorch builds, so the installation steps must be adjusted for CPU-only systems or other CUDA versions.
- The repository does not include an automated test suite or a published benchmark table, so validate training behavior in your target environment.
- The tutorial materials are still useful for concepts, but the implementation has evolved. See docs/tutorial-drift.md for the current differences.